diff --git a/mmdet/apis/ote/apis/detection/config_utils.py b/mmdet/apis/ote/apis/detection/config_utils.py
index b886ee952e8..5ac1369b3ae 100644
--- a/mmdet/apis/ote/apis/detection/config_utils.py
+++ b/mmdet/apis/ote/apis/detection/config_utils.py
@@ -190,9 +190,10 @@ def config_from_string(config_string: str) -> Config:
     return Config.fromfile(temp_file.name)
 
 
-def save_config_to_file(config: Config):
+def save_config_to_file(config: Config, filepath: Optional[str] = None):
     """ Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
-    filepath = os.path.join(config.work_dir, 'config.py')
+    if filepath is None:
+        filepath = os.path.join(config.work_dir, 'config.py')
     config_string = config_to_string(config)
     with open(filepath, 'w') as f:
         f.write(config_string)
diff --git a/mmdet/apis/ote/apis/detection/configuration.py b/mmdet/apis/ote/apis/detection/configuration.py
index 661c4b428b5..261dce0654c 100644
--- a/mmdet/apis/ote/apis/detection/configuration.py
+++ b/mmdet/apis/ote/apis/detection/configuration.py
@@ -155,7 +155,22 @@ class __POTParameter(ParameterGroup):
                             description="Quantization preset that defines quantization scheme",
                             editable=True,
                             visible_in_ui=True)
+
+    @attrs
+    class __DebugParameters(ParameterGroup):
+        header = string_attribute("Debugging Parameters")
+        description = header
+
+        enable_debug_dump = configurable_boolean(
+            default_value=True,
+            header="Enable data dumps for debugging",
+            description="Enable data dumps for debugging",
+            affects_outcome_of=ModelLifecycle.NONE
+        )
+
+
+
     learning_parameters = add_parameter_group(__LearningParameters)
     postprocessing = add_parameter_group(__Postprocessing)
     nncf_optimization = add_parameter_group(__NNCFOptimization)
     pot_parameters = add_parameter_group(__POTParameter)
+    debug_parameters = add_parameter_group(__DebugParameters)
diff --git a/mmdet/apis/ote/apis/detection/configuration.yaml b/mmdet/apis/ote/apis/detection/configuration.yaml
index edf68c1035f..54fada9c094 100644
--- a/mmdet/apis/ote/apis/detection/configuration.yaml
+++ b/mmdet/apis/ote/apis/detection/configuration.yaml
@@ -1,3 +1,23 @@
+debug_parameters:
+  description: Debugging Parameters
+  enable_debug_dump:
+    affects_outcome_of: NONE
+    default_value: true
+    description: Enable data dumps for debugging
+    editable: true
+    header: Enable data dumps for debugging
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: true
+    visible_in_ui: true
+    warning: null
+  header: Debugging Parameters
+  type: PARAMETER_GROUP
+  visible_in_ui: true
 description: Configuration for an object detection task
 header: Configuration for an object detection task
 learning_parameters:
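The new debug_parameters group defaults enable_debug_dump to true, which is why the sample script and the tests further down in this patch switch it off explicitly. A minimal sketch of toggling the flag through the regular OTE hyper-parameter flow, assuming a model_template has already been loaded (the variable name is illustrative and mirrors test_ote_training.py below):

    from ote_sdk.configuration.helper import create

    # Build the hyper-parameters from the template defaults, then disable
    # the debug dump before the TaskEnvironment is created.
    params = create(model_template.hyper_parameters.data)  # model_template: assumed to be loaded already
    params.debug_parameters.enable_debug_dump = False
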
diff --git a/mmdet/apis/ote/apis/detection/debug.py b/mmdet/apis/ote/apis/detection/debug.py
new file mode 100644
index 00000000000..93719d7e263
--- /dev/null
+++ b/mmdet/apis/ote/apis/detection/debug.py
@@ -0,0 +1,216 @@
+import datetime
+import logging
+import os
+import pickle
+import socket
+from functools import wraps
+from typing import Dict, Any
+
+from ote_sdk.entities.dataset_item import DatasetItemEntity
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.image import Image
+
+from mmdet.utils.logger import get_root_logger
+
+
+logger = get_root_logger()
+
+
+def get_dump_file_path():
+    full_path = os.path.join(
+        '/NOUS' if os.path.exists('/NOUS') else '/tmp',
+        'debug_dumps',
+        socket.gethostname(),
+        datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '.pkl')
+    return full_path
+
+
+def debug_trace(func):
+    @wraps(func)
+    def wrapped_function(self, *args, **kwargs):
+        class_name = self.__class__.__name__
+        func_name = func.__name__
+        if self._hyperparams.debug_parameters.enable_debug_dump:
+            dump_dict = {
+                'class_name': class_name,
+                'entrypoint': func_name,
+                'task': self,
+            }
+            if func_name not in debug_trace_registry:
+                raise ValueError(f'Debug tracing is not implemented for {func_name} method.')
+            dump_dict['arguments'] = debug_trace_registry[func_name](self, *args, **kwargs)
+            logger.warning(f'Saving debug dump for {class_name}.{func_name} call to {self._debug_dump_file_path}')
+            os.makedirs(os.path.dirname(self._debug_dump_file_path), exist_ok=True)
+            with open(self._debug_dump_file_path, 'ab') as fp:
+                pickle.dump(dump_dict, fp)
+        return func(self, *args, **kwargs)
+    return wrapped_function
+
+
+def infer_debug_trace(self, dataset, inference_parameters=None):
+    return {'dataset': dump_dataset(dataset)}
+
+
+def evaluate_debug_trace(self, output_result_set, evaluation_metric=None):
+    return {
+        'output_resultset': {
+            'purpose': output_result_set.purpose,
+            'ground_truth_dataset': dump_dataset(output_result_set.ground_truth_dataset),
+            'prediction_dataset': dump_dataset(output_result_set.prediction_dataset)
+        },
+        'evaluation_metric': evaluation_metric,
+    }
+
+
+def export_debug_trace(self, export_type, output_model):
+    return {
+        'export_type': export_type
+    }
+
+
+def train_debug_trace(self, dataset, output_model, train_parameters=None):
+    return {
+        'dataset': dump_dataset(dataset),
+        'train_parameters': None if train_parameters is None else {'resume': train_parameters.resume}
+    }
+
+
+debug_trace_registry = {
+    'infer': infer_debug_trace,
+    'train': train_debug_trace,
+    'evaluate': evaluate_debug_trace,
+    'export': export_debug_trace,
+}
+
+
+def dump_dataset_item(item: DatasetItemEntity):
+    dump = {
+        'subset': item.subset,
+        'numpy': item.numpy,
+        'roi': item.roi,
+        'annotation_scene': item.annotation_scene
+    }
+    return dump
+
+
+def load_dataset_item(dump: Dict[str, Any]):
+    return DatasetItemEntity(
+        media=Image(dump['numpy']),
+        annotation_scene=dump['annotation_scene'],
+        roi=dump['roi'],
+        subset=dump['subset'])
+
+
+def dump_dataset(dataset: DatasetEntity):
+    dump = {
+        'purpose': dataset.purpose,
+        'items': list(dump_dataset_item(item) for item in dataset)
+    }
+    return dump
+
+
+def load_dataset(dump: Dict[str, Any]):
+    return DatasetEntity(
+        items=[load_dataset_item(i) for i in dump['items']],
+        purpose=dump['purpose'])
+
+
+if __name__ == '__main__':
+    import argparse
+    from ote_sdk.entities.model import ModelEntity, ModelStatus
+    from ote_sdk.entities.resultset import ResultSetEntity
+
+
+    def parse_args():
+        parser = argparse.ArgumentParser()
+        parser.add_argument('dump_path')
+        return parser.parse_args()
+
+
+    def main():
+        args = parse_args()
+        assert os.path.exists(args.dump_path)
+
+        output_model = None
+        train_dataset = None
+
+        with open(args.dump_path, 'rb') as f:
+            while True:
+                print('reading dump record...')
+                logger = get_root_logger()
+                logger.setLevel(logging.ERROR)
+                try:
+                    dump = pickle.load(f)
+                except EOFError:
+                    print('no more records found in the dump file')
+                    break
+                logger.setLevel(logging.INFO)
+
+                task = dump['task']
+                # Disable debug dump when replaying another debug dump
+                task._task_environment.get_hyper_parameters().debug_parameters.enable_debug_dump = False
+                method_args = {}
+
+                entrypoint = dump['entrypoint']
+                print('*' * 80)
+
+                print(f'{type(task)=}, {entrypoint=}')
+                print('=' * 80)
+
+                while True:
+                    action = input('[r]eplay, [s]kip or [q]uit : [r] ')
+                    action = action.lower()
+                    if action == '':
+                        action = 'r'
+                    if action not in {'r', 's', 'q'}:
+                        continue
+                    else:
+                        break
+
+                if action == 's':
+                    print('skipping the step replay')
+                    continue
+                if action == 'q':
+                    print('quitting dump replay session')
+                    exit(0)
+
+                print('replaying the step')
+
+                if entrypoint == 'train':
+                    method_args['dataset'] = load_dataset(dump['arguments']['dataset'])
+                    train_dataset = method_args['dataset']
+                    method_args['output_model'] = ModelEntity(
+                        method_args['dataset'],
+                        task._task_environment,
+                        model_status=ModelStatus.NOT_READY)
+                    output_model = method_args['output_model']
+                    method_args['train_parameters'] = None
+                elif entrypoint == 'infer':
+                    method_args['dataset'] = load_dataset(dump['arguments']['dataset'])
+                    method_args['inference_parameters'] = None
+                elif entrypoint == 'export':
+                    method_args['output_model'] = ModelEntity(
+                        train_dataset,
+                        task._task_environment,
+                        model_status=ModelStatus.NOT_READY)
+                    output_model = method_args['output_model']
+                    method_args['export_type'] = dump['arguments']['export_type']
+                elif entrypoint == 'evaluate':
+                    output_model = ModelEntity(
+                        DatasetEntity(),
+                        task._task_environment,
+                        model_status=ModelStatus.SUCCESS)
+                    output_model.configuration.label_schema = task._task_environment.label_schema
+                    method_args['output_result_set'] = ResultSetEntity(
+                        model=output_model,
+                        ground_truth_dataset=load_dataset(dump['arguments']['output_resultset']['ground_truth_dataset']),
+                        prediction_dataset=load_dataset(dump['arguments']['output_resultset']['prediction_dataset'])
+                    )
+                    method_args['evaluation_metric'] = dump['arguments']['evaluation_metric']
+                else:
+                    raise RuntimeError(f'Unknown {entrypoint=}')
+
+                output = getattr(task, entrypoint)(**method_args)
+                print(f'\nOutput {type(output)=}\n\n\n\n')
+
+    main()
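Each @debug_trace-decorated call appends one pickled record (class_name, entrypoint, task, arguments) to the same dump file, so a single file can hold a whole session (train, then infer, then evaluate), and the module's __main__ block replays such a file interactively. A small sketch for listing the recorded calls without replaying them, assuming a dump path that follows the get_dump_file_path layout (the path below is hypothetical); note that unpickling a record also rebuilds the stored task object via its __setstate__, so the full mmdetection/OTE environment still has to be available:

    import pickle

    dump_path = '/tmp/debug_dumps/myhost/20220101_120000.pkl'  # hypothetical example path
    with open(dump_path, 'rb') as f:
        while True:
            try:
                record = pickle.load(f)  # one record per decorated call
            except EOFError:
                break
            print(record['class_name'], record['entrypoint'], sorted(record['arguments']))
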
diff --git a/mmdet/apis/ote/apis/detection/inference_task.py b/mmdet/apis/ote/apis/detection/inference_task.py
index 07f913bf0f9..5eabd8a0b6c 100644
--- a/mmdet/apis/ote/apis/detection/inference_task.py
+++ b/mmdet/apis/ote/apis/detection/inference_task.py
@@ -16,9 +16,9 @@
 import io
 import os
 import shutil
+import subprocess
 import tempfile
 import warnings
-from subprocess import run
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -42,8 +42,9 @@
 from ote_sdk.usecases.tasks.interfaces.unload_interface import IUnload
 
 from mmdet.apis import export_model
-from mmdet.apis.ote.apis.detection.config_utils import patch_config, prepare_for_testing, set_hyperparams
+from mmdet.apis.ote.apis.detection.config_utils import patch_config, prepare_for_testing, save_config_to_file, set_hyperparams
 from mmdet.apis.ote.apis.detection.configuration import OTEDetectionConfig
+from mmdet.apis.ote.apis.detection.debug import debug_trace, get_dump_file_path
 from mmdet.apis.ote.apis.detection.ote_utils import InferenceProgressCallback
 from mmdet.datasets import build_dataloader, build_dataset
 from mmdet.models import build_detector
@@ -57,6 +58,7 @@ class OTEDetectionInferenceTask(IInferenceTask, IExportTask, IEvaluationTask, IUnload):
 
     _task_environment: TaskEnvironment
+    _debug_dump_file_path: str = get_dump_file_path()
 
     def __init__(self, task_environment: TaskEnvironment):
         """"
@@ -64,11 +66,11 @@ def __init__(self, task_environment: TaskEnvironment):
         """
         logger.info('Loading OTEDetectionTask')
 
-        print('ENVIRONMENT:')
+        logger.info('ENVIRONMENT:')
         for name, val in collect_env().items():
-            print(f'{name}: {val}')
-        print('pip list:')
-        run('pip list', shell=True, check=True)
+            logger.info(f'{name}: {val}')
+        logger.info('pip list:')
+        logger.info(subprocess.check_output(['pip', 'list'], universal_newlines=True))
 
         self._task_environment = task_environment
 
@@ -102,10 +104,12 @@ def __init__(self, task_environment: TaskEnvironment):
 
         self._should_stop = False
         logger.info('Task initialization completed')
 
+    @property
     def _hyperparams(self):
         return self._task_environment.get_hyper_parameters(OTEDetectionConfig)
 
+
     def _load_model(self, model: ModelEntity):
         if model is not None:
             # If a model has been trained and saved for the task already, create empty model and load weights here
@@ -160,6 +164,52 @@ def _create_model(config: Config, from_scratch: bool = False):
 
         return model
 
+    def __getstate__(self):
+        from ote_sdk.configuration.helper import convert
+
+        model = {
+            'weights': self._model.state_dict(),
+            'config': self._config,
+            'confidence_threshold': self.confidence_threshold,
+        }
+        environment = {
+            'model_template': self._task_environment.model_template,
+            'hyperparams': convert(self._hyperparams, str),
+            'label_schema': self._task_environment.label_schema,
+        }
+        return {
+            'environment': environment,
+            'model': model,
+        }
+
+
+    def __setstate__(self, state):
+        from ote_sdk.configuration.helper import create
+        from dataclasses import asdict
+        import yaml
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_template = state['environment']['model_template']
+            save_config_to_file(state['model']['config'], os.path.join(tmpdir, 'model.py'))
+            with open(os.path.join(tmpdir, 'template.yaml'), 'wt') as f:
+                yaml.dump(asdict(model_template), f)
+            model_template.model_template_path = os.path.join(tmpdir, 'template.yaml')
+
+            hyperparams = create(state['environment']['hyperparams'])
+            label_schema = state['environment']['label_schema']
+            environment = TaskEnvironment(
+                model_template=model_template,
+                model=None,
+                hyper_parameters=hyperparams,
+                label_schema=label_schema,
+            )
+            self.__init__(environment)
+
+        self._model.load_state_dict(state['model']['weights'])
+        self._config = state['model']['config']
+        self.confidence_threshold = state['model']['confidence_threshold']
+
+
     def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_threshold=0.0):
         """ Loop over dataset again to assign predictions. Convert from MMDetection format to OTE format. """
         for dataset_item, (all_bboxes, fmap) in zip(dataset, prediction_results):
@@ -193,6 +243,7 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_th
 
             dataset_item.append_metadata_item(active_score)
 
+    @debug_trace
     def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
         """ Analyzes a dataset using the latest inference model. """
 
@@ -272,6 +323,7 @@ def dummy_dump_features_hook(mod, inp, out):
 
         return eval_predictions, metric
 
+    @debug_trace
     def evaluate(self,
                  output_result_set: ResultSetEntity,
                  evaluation_metric: Optional[str] = None):
@@ -316,6 +368,8 @@ def unload(self):
             logger.warning(f"Done unloading. "
                            f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")
 
+
+    @debug_trace
     def export(self,
                export_type: ExportType,
                output_model: ModelEntity):
@@ -350,6 +404,7 @@ def export(self,
             raise RuntimeError('Optimization was unsuccessful.') from ex
         logger.info('Exporting completed')
 
+
     def _delete_scratch_space(self):
         """
         Remove model checkpoints and mmdet logs
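The __getstate__/__setstate__ pair above is what lets debug_trace pickle the whole task: only the weights, the mmdet config and the pieces needed to rebuild a TaskEnvironment are stored, and unpickling goes back through __init__ before the weights are restored. A toy sketch of the same pickle protocol, independent of the OTE classes and only meant to show the call order (not the actual task implementation):

    import pickle

    class Task:
        def __init__(self, config):
            self.config = config
            self.session = object()  # stands in for heavy runtime state that should not be pickled

        def __getstate__(self):
            # Keep only what is needed to rebuild the object.
            return {'config': self.config}

        def __setstate__(self, state):
            # pickle bypasses __init__ by default; calling it here mirrors what
            # OTEDetectionInferenceTask.__setstate__ does before loading weights.
            self.__init__(state['config'])

    restored = pickle.loads(pickle.dumps(Task({'lr': 0.01})))
    print(restored.config)  # -> {'lr': 0.01}
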
diff --git a/mmdet/apis/ote/apis/detection/train_task.py b/mmdet/apis/ote/apis/detection/train_task.py
index 1b6a447e1fb..6b45f7449b1 100644
--- a/mmdet/apis/ote/apis/detection/train_task.py
+++ b/mmdet/apis/ote/apis/detection/train_task.py
@@ -34,6 +34,7 @@
 
 from mmdet.apis import train_detector
 from mmdet.apis.ote.apis.detection.config_utils import prepare_for_training, set_hyperparams
+from mmdet.apis.ote.apis.detection.debug import debug_trace
 from mmdet.apis.ote.apis.detection.inference_task import OTEDetectionInferenceTask
 from mmdet.apis.ote.apis.detection.ote_utils import TrainingProgressCallback
 from mmdet.apis.ote.extension.utils.hooks import OTELoggerHook
@@ -70,6 +71,7 @@ def _generate_training_metrics(self, learning_curves, map) -> Optional[List[Metr
 
         return output
 
+    @debug_trace
     def train(self, dataset: DatasetEntity, output_model: ModelEntity, train_parameters: Optional[TrainParameters] = None):
         """ Trains a model on a dataset """
" f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory") + + @debug_trace def export(self, export_type: ExportType, output_model: ModelEntity): @@ -350,6 +404,7 @@ def export(self, raise RuntimeError('Optimization was unsuccessful.') from ex logger.info('Exporting completed') + def _delete_scratch_space(self): """ Remove model checkpoints and mmdet logs diff --git a/mmdet/apis/ote/apis/detection/train_task.py b/mmdet/apis/ote/apis/detection/train_task.py index 1b6a447e1fb..6b45f7449b1 100644 --- a/mmdet/apis/ote/apis/detection/train_task.py +++ b/mmdet/apis/ote/apis/detection/train_task.py @@ -34,6 +34,7 @@ from mmdet.apis import train_detector from mmdet.apis.ote.apis.detection.config_utils import prepare_for_training, set_hyperparams +from mmdet.apis.ote.apis.detection.debug import debug_trace from mmdet.apis.ote.apis.detection.inference_task import OTEDetectionInferenceTask from mmdet.apis.ote.apis.detection.ote_utils import TrainingProgressCallback from mmdet.apis.ote.extension.utils.hooks import OTELoggerHook @@ -70,6 +71,7 @@ def _generate_training_metrics(self, learning_curves, map) -> Optional[List[Metr return output + @debug_trace def train(self, dataset: DatasetEntity, output_model: ModelEntity, train_parameters: Optional[TrainParameters] = None): """ Trains a model on a dataset """ diff --git a/mmdet/apis/ote/sample/sample.py b/mmdet/apis/ote/sample/sample.py index 0491fcf5130..e36586bd3ef 100644 --- a/mmdet/apis/ote/sample/sample.py +++ b/mmdet/apis/ote/sample/sample.py @@ -118,6 +118,7 @@ def main(args): params.learning_parameters.num_iters = 5 params.learning_parameters.learning_rate_warmup_iters = 1 params.learning_parameters.batch_size = 2 + params.debug_parameters.enable_debug_dump = False logger.info('Setup environment') environment = TaskEnvironment(model=None, hyper_parameters=params, label_schema=labels_schema, model_template=model_template) diff --git a/tests/test_ote_api.py b/tests/test_ote_api.py index 77c39a7137d..13e4e497c5d 100644 --- a/tests/test_ote_api.py +++ b/tests/test_ote_api.py @@ -162,6 +162,7 @@ def setup_configurable_parameters(self, template_dir, num_iters=10): hyper_parameters.learning_parameters.num_iters = num_iters hyper_parameters.postprocessing.result_based_confidence_threshold = False hyper_parameters.postprocessing.confidence_threshold = 0.1 + hyper_parameters.debug_parameters.enable_debug_dump = False return hyper_parameters, model_template @e2e_pytest_api diff --git a/tests/test_ote_training.py b/tests/test_ote_training.py index c20a6d67106..9746fcc101c 100644 --- a/tests/test_ote_training.py +++ b/tests/test_ote_training.py @@ -243,6 +243,7 @@ def _run_ote_training(self, data_collector): logger.debug('Set hyperparameters') params = create(self.model_template.hyper_parameters.data) + params.debug_parameters.enable_debug_dump = False params.learning_parameters.num_iters = self.num_training_iters params.learning_parameters.batch_size = self.batch_size