diff --git a/dagshub/auth/token_auth.py b/dagshub/auth/token_auth.py index 31ec32ac..7ba3a70a 100644 --- a/dagshub/auth/token_auth.py +++ b/dagshub/auth/token_auth.py @@ -37,7 +37,7 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: def can_renegotiate(self): # Env var tokens cannot renegotiate, every other token type can - return not type(self._token) is EnvVarDagshubToken + return type(self._token) is not EnvVarDagshubToken def renegotiate_token(self): if not self._token_storage.is_valid_token(self._token, self._host): diff --git a/dagshub/data_engine/annotation/importer.py b/dagshub/data_engine/annotation/importer.py index c19212de..c4c86592 100644 --- a/dagshub/data_engine/annotation/importer.py +++ b/dagshub/data_engine/annotation/importer.py @@ -1,13 +1,21 @@ from difflib import SequenceMatcher from pathlib import Path, PurePosixPath, PurePath from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable, List - -from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip +from typing import TYPE_CHECKING, Dict, Literal, Optional, Union, Sequence, Mapping, Callable, List + +from dagshub_annotation_converter.converters.coco import load_coco_from_file +from dagshub_annotation_converter.converters.cvat import ( + load_cvat_from_fs, + load_cvat_from_zip, + load_cvat_from_xml_file, +) +from dagshub_annotation_converter.converters.mot import load_mot_from_dir, load_mot_from_fs, load_mot_from_zip from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs +from dagshub_annotation_converter.converters.label_studio_video import video_ir_to_ls_video_tasks from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask from dagshub_annotation_converter.formats.yolo import YoloContext from dagshub_annotation_converter.ir.image.annotations.base import IRAnnotationBase +from dagshub_annotation_converter.ir.video import IRVideoBBoxAnnotation from dagshub.common.api import UserAPI from dagshub.common.api.repo import PathNotFoundError @@ -16,7 +24,7 @@ if TYPE_CHECKING: from dagshub.data_engine.model.datasource import Datasource -AnnotationType = Literal["yolo", "cvat"] +AnnotationType = Literal["yolo", "cvat", "coco", "mot", "cvat_video"] AnnotationLocation = Literal["repo", "disk"] @@ -57,6 +65,10 @@ def __init__( 'Add `yolo_type="bbox"|"segmentation"|pose"` to the arguments.' ) + @property + def is_video_format(self) -> bool: + return self.annotations_type in ("mot", "cvat_video") + def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: # Double check that the annotation file exists if self.load_from == "disk": @@ -84,15 +96,130 @@ def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: annotation_type=self.additional_args["yolo_type"], meta_file=annotations_file ) elif self.annotations_type == "cvat": - annotation_dict = load_cvat_from_zip(annotations_file) + if annotations_file.is_dir(): + annotation_dict = self._flatten_cvat_fs_annotations(load_cvat_from_fs(annotations_file)) + else: + result = load_cvat_from_zip(annotations_file) + if self._is_video_annotation_dict(result): + annotation_dict = self._flatten_video_annotations(result) + else: + annotation_dict = result + elif self.annotations_type == "coco": + annotation_dict, _ = load_coco_from_file(annotations_file) + elif self.annotations_type == "mot": + mot_kwargs = {} + if "image_width" in self.additional_args: + mot_kwargs["image_width"] = self.additional_args["image_width"] + if "image_height" in self.additional_args: + mot_kwargs["image_height"] = self.additional_args["image_height"] + if "video_name" in self.additional_args: + mot_kwargs["video_file"] = self.additional_args["video_name"] + if annotations_file.is_dir(): + video_files = self.additional_args.get("video_files") + raw_datasource_path = self.additional_args.get("datasource_path") + if raw_datasource_path is None: + raw_datasource_path = self.ds.source.source_prefix + datasource_path = PurePosixPath(raw_datasource_path).as_posix().lstrip("/") + if datasource_path == ".": + datasource_path = "" + mot_results = load_mot_from_fs( + annotations_file, + image_width=mot_kwargs.get("image_width"), + image_height=mot_kwargs.get("image_height"), + video_files=video_files, + datasource_path=datasource_path, + ) + annotation_dict = self._flatten_mot_fs_annotations(mot_results) + elif annotations_file.suffix == ".zip": + video_anns, _ = load_mot_from_zip(annotations_file, **mot_kwargs) + annotation_dict = self._flatten_video_annotations(video_anns) + else: + video_anns, _ = load_mot_from_dir(annotations_file, **mot_kwargs) + annotation_dict = self._flatten_video_annotations(video_anns) + elif self.annotations_type == "cvat_video": + cvat_kwargs = {} + if "image_width" in self.additional_args: + cvat_kwargs["image_width"] = self.additional_args["image_width"] + if "image_height" in self.additional_args: + cvat_kwargs["image_height"] = self.additional_args["image_height"] + if annotations_file.is_dir(): + raw = load_cvat_from_fs(annotations_file, **cvat_kwargs) + annotation_dict = self._flatten_cvat_fs_annotations(raw) + elif annotations_file.suffix == ".zip": + result = load_cvat_from_zip(annotations_file, **cvat_kwargs) + if self._is_video_annotation_dict(result): + annotation_dict = self._flatten_video_annotations(result) + else: + annotation_dict = result + else: + result = load_cvat_from_xml_file(annotations_file, **cvat_kwargs) + if self._is_video_annotation_dict(result): + annotation_dict = self._flatten_video_annotations(result) + else: + annotation_dict = result + else: + raise ValueError(f"Unsupported annotation type: {self.annotations_type}") return annotation_dict + @staticmethod + def _is_video_annotation_dict(result) -> bool: + """Check if the result from a CVAT loader is video annotations (int keys) vs image annotations (str keys).""" + if not isinstance(result, dict) or len(result) == 0: + return False + first_key = next(iter(result.keys())) + return isinstance(first_key, int) + + def _flatten_video_annotations( + self, + frame_annotations: Dict[int, Sequence[IRAnnotationBase]], + ) -> Dict[str, Sequence[IRAnnotationBase]]: + """Flatten frame-indexed video annotations into a single entry keyed by video name.""" + video_name = self.additional_args.get("video_name", self.annotations_file.stem) + all_anns: List[IRAnnotationBase] = [] + for frame_anns in frame_annotations.values(): + all_anns.extend(frame_anns) + return {video_name: all_anns} + + def _flatten_cvat_fs_annotations( + self, fs_annotations: Mapping[str, object] + ) -> Dict[str, Sequence[IRAnnotationBase]]: + flattened: Dict[str, List[IRAnnotationBase]] = {} + for rel_path, result in fs_annotations.items(): + if not isinstance(result, dict): + continue + if self._is_video_annotation_dict(result): + video_key = Path(rel_path).stem + flattened.setdefault(video_key, []) + for frame_anns in result.values(): + flattened[video_key].extend(frame_anns) + else: + for filename, anns in result.items(): + flattened.setdefault(filename, []) + flattened[filename].extend(anns) + return flattened + + def _flatten_mot_fs_annotations( + self, + fs_annotations: Mapping[str, object], + ) -> Dict[str, Sequence[IRAnnotationBase]]: + flattened: Dict[str, List[IRAnnotationBase]] = {} + for rel_path, result in fs_annotations.items(): + if not isinstance(result, tuple) or len(result) != 2: + continue + frame_annotations = result[0] + if not isinstance(frame_annotations, dict): + continue + sequence_name = Path(rel_path).stem if rel_path not in (".", "") else self.annotations_file.stem + flattened.setdefault(sequence_name, []) + for frame_anns in frame_annotations.values(): + flattened[sequence_name].extend(frame_anns) + return flattened + def download_annotations(self, dest_dir: Path): log_message("Downloading annotations from repository") repoApi = self.ds.source.repoApi - if self.annotations_type == "cvat": - # Download just the annotation file + if self.annotations_type in ("cvat", "cvat_video"): repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) elif self.annotations_type == "yolo": # Download the dataset .yaml file and the images + annotations @@ -104,6 +231,8 @@ def download_annotations(self, dest_dir: Path): # Download the annotation data assert context.path is not None repoApi.download(self.annotations_file.parent / context.path, dest_dir, keep_source_prefix=True) + elif self.annotations_type in ("coco", "mot"): + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) @staticmethod def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path]) -> AnnotationLocation: @@ -153,8 +282,12 @@ def remap_annotations( ) continue for ann in anns: - assert ann.filename is not None - ann.filename = remap_func(ann.filename) + if ann.filename is not None: + ann.filename = remap_func(ann.filename) + else: + if not self.is_video_format: + raise ValueError(f"Non-video annotation has no filename: {ann}") + ann.filename = new_filename remapped[new_filename] = anns return remapped @@ -288,6 +421,8 @@ def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRAnnotationBas """ Converts the annotations to Label Studio tasks. """ + if self.is_video_format: + return self._convert_to_ls_video_tasks(annotations) current_user_id = UserAPI.get_current_user(self.ds.source.repoApi.host).user_id tasks = {} for filename, anns in annotations.items(): @@ -296,3 +431,20 @@ def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRAnnotationBas t.add_ir_annotations(anns) tasks[filename] = t.model_dump_json().encode("utf-8") return tasks + + def _convert_to_ls_video_tasks( + self, annotations: Mapping[str, Sequence[IRAnnotationBase]] + ) -> Mapping[str, bytes]: + """ + Converts video annotations to Label Studio video tasks. + """ + tasks = {} + for filename, anns in annotations.items(): + video_anns = [a for a in anns if isinstance(a, IRVideoBBoxAnnotation)] + if not video_anns: + continue + video_path = self.ds.source.raw_path(filename) + ls_tasks = video_ir_to_ls_video_tasks(video_anns, video_path=video_path) + if ls_tasks: + tasks[filename] = ls_tasks[0].model_dump_json().encode("utf-8") + return tasks diff --git a/dagshub/data_engine/annotation/metadata.py b/dagshub/data_engine/annotation/metadata.py index 06f7bc28..0b080e0f 100644 --- a/dagshub/data_engine/annotation/metadata.py +++ b/dagshub/data_engine/annotation/metadata.py @@ -22,6 +22,11 @@ from dagshub.data_engine.model.datapoint import Datapoint +from dagshub_annotation_converter.formats.label_studio.videorectangle import VideoRectangleAnnotation +from dagshub_annotation_converter.formats.label_studio.task import task_lookup as _task_lookup + +_task_lookup["videorectangle"] = VideoRectangleAnnotation + class AnnotationMetaDict(dict): def __init__(self, annotation: "MetadataAnnotations", *args, **kwargs): @@ -271,6 +276,28 @@ def add_image_pose( self.annotations.append(ann) self._update_datapoint() + def add_coco_annotation( + self, + coco_json: str, + ): + """ + Add annotations from a COCO-format JSON string. + + Args: + coco_json: A COCO-format JSON string with ``categories``, ``images``, and ``annotations`` keys. + """ + from dagshub_annotation_converter.converters.coco import load_coco_from_json_string + + grouped, _ = load_coco_from_json_string(coco_json) + new_anns: list[IRAnnotationBase] = [] + for anns in grouped.values(): + for ann in anns: + ann.filename = self.datapoint.path + new_anns.append(ann) + self.annotations.extend(new_anns) + log_message(f"Added {len(new_anns)} COCO annotation(s) to datapoint {self.datapoint.path}") + self._update_datapoint() + def add_yolo_annotation( self, annotation_type: Literal["bbox", "segmentation", "pose"], diff --git a/dagshub/data_engine/model/query_result.py b/dagshub/data_engine/model/query_result.py index b986b5c3..3c283194 100644 --- a/dagshub/data_engine/model/query_result.py +++ b/dagshub/data_engine/model/query_result.py @@ -15,10 +15,16 @@ import dacite import dagshub_annotation_converter.converters.yolo import rich.progress +from dagshub_annotation_converter.converters.coco import export_to_coco_file +from dagshub_annotation_converter.converters.cvat import export_cvat_video_to_zip, export_cvat_videos_to_zips +from dagshub_annotation_converter.converters.mot import export_mot_sequences_to_dirs, export_mot_to_dir +from dagshub_annotation_converter.formats.coco import CocoContext +from dagshub_annotation_converter.formats.mot import MOTContext from dagshub_annotation_converter.formats.yolo import YoloContext from dagshub_annotation_converter.formats.yolo.categories import Categories from dagshub_annotation_converter.formats.yolo.common import ir_mapping from dagshub_annotation_converter.ir.image import IRImageAnnotationBase +from dagshub_annotation_converter.ir.video import IRVideoBBoxAnnotation from pydantic import ValidationError from dagshub.auth import get_token @@ -778,6 +784,44 @@ def _get_all_annotations(self, annotation_field: str) -> List[IRImageAnnotationB annotations.extend(dp.metadata[annotation_field].annotations) return annotations + def _get_all_video_annotations(self, annotation_field: str) -> List[IRVideoBBoxAnnotation]: + all_anns = self._get_all_annotations(annotation_field) + return [a for a in all_anns if isinstance(a, IRVideoBBoxAnnotation)] + + def _prepare_video_file_for_export(self, local_root: Path, repo_relative_filename: str) -> Optional[Path]: + ann_path = Path(repo_relative_filename) + primary = local_root / ann_path + if primary.exists(): + return primary + source_prefix = Path(self.datasource.source.source_prefix) + with_prefix = local_root / source_prefix / ann_path + if with_prefix.exists(): + return with_prefix + return None + + @staticmethod + def _get_annotation_filename(ann: IRVideoBBoxAnnotation) -> Optional[str]: + filename = ann.filename + if filename is None: + return None + if isinstance(filename, (list, tuple)): + if len(filename) == 0: + return None + if len(filename) > 1: + raise ValueError(f"Annotation has multiple filenames: {filename}") + filename = filename[0] + return str(filename) + + def _resolve_annotation_field(self, annotation_field: Optional[str]) -> str: + if annotation_field is not None: + return annotation_field + annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) + if len(annotation_fields) == 0: + raise ValueError("No annotation fields found in the datasource") + annotation_field = annotation_fields[0] + log_message(f"Using annotations from field {annotation_field}") + return annotation_field + def export_as_yolo( self, download_dir: Optional[Union[str, Path]] = None, @@ -803,12 +847,7 @@ def export_as_yolo( Returns: The path to the YAML file with the metadata. Pass this path to ``YOLO.train()`` to train a model. """ - if annotation_field is None: - annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) - if len(annotation_fields) == 0: - raise ValueError("No annotation fields found in the datasource") - annotation_field = annotation_fields[0] - log_message(f"Using annotations from field {annotation_field}") + annotation_field = self._resolve_annotation_field(annotation_field) if download_dir is None: download_dir = Path("dagshub_export") @@ -861,6 +900,266 @@ def export_as_yolo( log_message(f"Done! Saved YOLO Dataset, YAML file is at {yaml_path.absolute()}") return yaml_path + def export_as_coco( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + output_filename: str = "annotations.json", + classes: Optional[Dict[int, str]] = None, + ) -> Path: + """ + Downloads the files and exports annotations in COCO format. + + Args: + download_dir: Where to download the files. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + output_filename: Name of the output COCO JSON file. Default is ``annotations.json``. + classes: Category mapping for the COCO dataset as ``{id: name}``. + If ``None``, categories will be inferred from the annotations. + + Returns: + Path to the exported COCO JSON file. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + + annotations = self._get_all_annotations(annotation_field) + if not annotations: + raise RuntimeError("No annotations found to export") + + context = CocoContext() + if classes is not None: + context.categories = dict(classes) + + # Add the source prefix to all annotations + for ann in annotations: + ann.filename = os.path.join(self.datasource.source.source_prefix, ann.filename) + + image_download_path = download_dir / "data" + log_message("Downloading image files...") + self.download_files(image_download_path) + + output_path = download_dir / output_filename + log_message("Exporting COCO annotations...") + result_path = export_to_coco_file(annotations, output_path, context=context) + log_message(f"Done! Saved COCO annotations to {result_path.absolute()}") + return result_path + + def export_as_mot( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + image_width: Optional[int] = None, + image_height: Optional[int] = None, + ) -> Path: + """ + Exports video annotations in MOT (Multiple Object Tracking) format. + + The output follows the MOT Challenge directory structure:: + + output_dir/ + gt/ + gt.txt + labels.txt + seqinfo.ini + + Args: + download_dir: Where to export. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + image_width: Frame width. If None, inferred from annotations. + image_height: Frame height. If None, inferred from annotations. + + Returns: + Path to the exported MOT directory. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + labels_dir = download_dir / "labels" + labels_dir.mkdir(parents=True, exist_ok=True) + + video_annotations = self._get_all_video_annotations(annotation_field) + if not video_annotations: + raise RuntimeError("No video annotations found to export") + + source_names = sorted( + { + Path(ann_filename).name + for ann_filename in (self._get_annotation_filename(ann) for ann in video_annotations) + if ann_filename + } + ) + has_multiple_sources = len(source_names) > 1 + + local_download_root: Optional[Path] = None + if image_width is None or image_height is None: + log_message("Missing video dimensions in annotations, downloading videos for converter-side probing...") + local_download_root = self.download_files(download_dir / "data", keep_source_prefix=True) + + log_message("Exporting MOT annotations...") + if has_multiple_sources: + video_files: Optional[Dict[str, Union[str, Path]]] = None + if local_download_root is not None: + video_files = {} + for ann_filename in { + self._get_annotation_filename(ann) + for ann in video_annotations + if self._get_annotation_filename(ann) + }: + assert ann_filename is not None + sequence_name = Path(ann_filename).stem + local_video = self._prepare_video_file_for_export(local_download_root, ann_filename) + if local_video is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ann_filename}' under " + f"'{local_download_root}'." + ) + video_files[sequence_name] = local_video + + context = MOTContext() + context.image_width = image_width + context.image_height = image_height + export_mot_sequences_to_dirs(video_annotations, context, labels_dir, video_files=video_files) + result_path = labels_dir + else: + video_file: Optional[Path] = None + if local_download_root is not None: + ref_filename = next((self._get_annotation_filename(a) for a in video_annotations), None) + if ref_filename is None: + raise FileNotFoundError("Missing annotation filename for MOT export.") + video_file = self._prepare_video_file_for_export(local_download_root, ref_filename) + if video_file is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ref_filename}' " + f"under '{local_download_root}'." + ) + + context = MOTContext() + context.image_width = image_width + context.image_height = image_height + single_name = Path(source_names[0]).stem if source_names else "sequence" + output_dir = labels_dir / single_name + result_path = export_mot_to_dir(video_annotations, context, output_dir, video_file=video_file) + + log_message(f"Done! Saved MOT annotations to {result_path.absolute()}") + return result_path + + def export_as_cvat_video( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + video_name: str = "video.mp4", + image_width: Optional[int] = None, + image_height: Optional[int] = None, + ) -> Path: + """ + Exports video annotations in CVAT video ZIP format. + + Args: + download_dir: Where to export. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + video_name: Name of the source video to embed in the XML metadata. + image_width: Frame width. If None, inferred from annotations. + image_height: Frame height. If None, inferred from annotations. + + Returns: + Path to the exported CVAT video ZIP file for single-video exports, + or output directory for multi-video exports. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + + video_annotations = self._get_all_video_annotations(annotation_field) + if not video_annotations: + raise RuntimeError("No video annotations found to export") + + source_names = sorted( + { + Path(ann_filename).name + for ann_filename in (self._get_annotation_filename(ann) for ann in video_annotations) + if ann_filename + } + ) + has_multiple_sources = len(source_names) > 1 + + log_message("Exporting CVAT video annotations...") + local_download_root: Optional[Path] = None + if not has_multiple_sources and (image_width is None or image_height is None): + log_message("Missing video dimensions in annotations, downloading videos for converter-side probing...") + local_download_root = self.download_files(download_dir / "data", keep_source_prefix=True) + + if has_multiple_sources: + video_files: Optional[Dict[str, Union[str, Path]]] = None + if image_width is None or image_height is None: + log_message("Missing video dimensions in annotations, downloading videos for converter-side probing...") + local_download_root = self.download_files(download_dir / "data", keep_source_prefix=True) + video_files = {} + for ann_filename in { + self._get_annotation_filename(ann) + for ann in video_annotations + if self._get_annotation_filename(ann) + }: + assert ann_filename is not None + local_video = self._prepare_video_file_for_export(local_download_root, ann_filename) + if local_video is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ann_filename}' " + f"under '{local_download_root}'." + ) + ann_path = Path(ann_filename) + video_files[ann_filename] = local_video + video_files[ann_path.name] = local_video + video_files[ann_path.stem] = local_video + + output_dir = download_dir / "labels" + output_dir.mkdir(parents=True, exist_ok=True) + export_cvat_videos_to_zips( + video_annotations, + output_dir, + image_width=image_width, + image_height=image_height, + video_files=video_files if video_files else None, + ) + result_path = output_dir + else: + single_video_file: Optional[Path] = None + if local_download_root is not None: + ref_filename = next((self._get_annotation_filename(a) for a in video_annotations), None) + if ref_filename is None: + raise FileNotFoundError("Missing annotation filename for single-video CVAT export.") + single_video_file = self._prepare_video_file_for_export(local_download_root, ref_filename) + if single_video_file is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ref_filename}' " + f"under '{local_download_root}'." + ) + + labels_dir = download_dir / "labels" + labels_dir.mkdir(parents=True, exist_ok=True) + if source_names: + output_name = f"{Path(source_names[0]).name}.zip" + else: + output_name = "annotations.zip" + output_path = labels_dir / output_name + result_path = export_cvat_video_to_zip( + video_annotations, + output_path, + video_name=video_name, + image_width=image_width, + image_height=image_height, + video_file=single_video_file, + ) + log_message(f"Done! Saved CVAT video annotations to {result_path.absolute()}") + return result_path + def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset": """ Creates a voxel51 dataset that can be used with\ diff --git a/tests/data_engine/annotation_import/test_coco.py b/tests/data_engine/annotation_import/test_coco.py new file mode 100644 index 00000000..9b238fd1 --- /dev/null +++ b/tests/data_engine/annotation_import/test_coco.py @@ -0,0 +1,218 @@ +import datetime +import json +from pathlib import PurePosixPath +from unittest.mock import patch, PropertyMock + +import pytest +from dagshub_annotation_converter.ir.image import ( + IRBBoxImageAnnotation, + CoordinateStyle, +) + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +@pytest.fixture(autouse=True) +def mock_source_prefix(ds): + with patch.object(type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()): + yield + + +# --- import --- + + +def test_import_coco_from_file(ds, tmp_path): + _write_coco(tmp_path, _make_coco_json()) + importer = AnnotationImporter(ds, "coco", tmp_path / "annotations.json", load_from="disk") + result = importer.import_annotations() + + assert "image1.jpg" in result + assert len(result["image1.jpg"]) == 1 + assert isinstance(result["image1.jpg"][0], IRBBoxImageAnnotation) + + +def test_import_coco_nonexistent_raises(ds, tmp_path): + importer = AnnotationImporter(ds, "coco", tmp_path / "nope.json", load_from="disk") + with pytest.raises(AnnotationsNotFoundError): + importer.import_annotations() + + +def test_coco_convert_to_ls_tasks(ds, tmp_path, mock_dagshub_auth): + importer = AnnotationImporter(ds, "coco", tmp_path / "ann.json", load_from="disk") + bbox = IRBBoxImageAnnotation( + filename="test.jpg", categories={"cat": 1.0}, + top=0.1, left=0.1, width=0.2, height=0.2, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.NORMALIZED, + ) + tasks = importer.convert_to_ls_tasks({"test.jpg": [bbox]}) + + assert "test.jpg" in tasks + task_json = json.loads(tasks["test.jpg"]) + assert "annotations" in task_json + assert len(task_json["annotations"]) > 0 + + +# --- add_coco_annotation --- + + +def test_add_coco_annotation_rewrites_filename(ds, mock_dagshub_auth): + dp = Datapoint(datasource=ds, path="my_images/photo.jpg", datapoint_id=0, metadata={}) + meta_ann = MetadataAnnotations(datapoint=dp, field="ann") + meta_ann.add_coco_annotation(json.dumps(_make_coco_json())) + + assert len(meta_ann.annotations) == 1 + assert isinstance(meta_ann.annotations[0], IRBBoxImageAnnotation) + assert meta_ann.annotations[0].filename == "my_images/photo.jpg" + + +# --- _resolve_annotation_field --- + + +def test_resolve_explicit_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field("explicit") == "explicit" + + +def test_resolve_auto_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field(None) == "my_ann" + + +def test_resolve_no_fields_raises(ds): + qr = _make_qr(ds, [], ann_field=None) + with pytest.raises(ValueError, match="No annotation fields"): + qr._resolve_annotation_field(None) + + +def test_resolve_picks_alphabetically_first(ds): + fields = [] + for name in ["zebra_ann", "alpha_ann"]: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=name, + multiple=False, valueType=MetadataFieldType.BLOB, + name=name, tags={ReservedTags.ANNOTATION.value}, + )) + qr = QueryResult(datasource=ds, _entries=[], fields=fields) + assert qr._resolve_annotation_field(None) == "alpha_ann" + + +# --- export_as_coco --- + + +def test_export_coco_bbox_coordinates(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + ann = IRBBoxImageAnnotation( + filename="images/test.jpg", categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert coco["annotations"][0]["bbox"] == [10.0, 20.0, 30.0, 40.0] + + +def test_export_coco_no_annotations_raises(ds, tmp_path): + dp = Datapoint(datasource=ds, path="test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with pytest.raises(RuntimeError, match="No annotations found"): + qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + +def test_export_coco_explicit_classes(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, annotation_field="ann", classes={1: "cat", 2: "dog"} + ) + + coco = json.loads(result.read_text()) + assert "cat" in {c["name"] for c in coco["categories"]} + + +def test_export_coco_custom_filename(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, annotation_field="ann", output_filename="custom.json" + ) + + assert result.name == "custom.json" + + +def test_export_coco_multiple_datapoints(ds, tmp_path): + dps = [] + for i, name in enumerate(["a.jpg", "b.jpg"]): + dp = Datapoint(datasource=ds, path=name, datapoint_id=i, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox(name)] + ) + dps.append(dp) + + qr = _make_qr(ds, dps, ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert len(coco["annotations"]) == 2 + assert len(coco["images"]) == 2 + + +# --- helpers --- + + +def _make_coco_json(): + return { + "categories": [{"id": 1, "name": "cat"}], + "images": [{"id": 1, "width": 640, "height": 480, "file_name": "image1.jpg"}], + "annotations": [{"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 30, 40]}], + } + + +def _write_coco(tmp_path, coco): + (tmp_path / "annotations.json").write_text(json.dumps(coco)) + + +def _make_image_bbox(filename="test.jpg") -> IRBBoxImageAnnotation: + return IRBBoxImageAnnotation( + filename=filename, categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/annotation_import/test_cvat_video.py b/tests/data_engine/annotation_import/test_cvat_video.py new file mode 100644 index 00000000..3676b82d --- /dev/null +++ b/tests/data_engine/annotation_import/test_cvat_video.py @@ -0,0 +1,276 @@ +import datetime +import zipfile +from pathlib import PurePosixPath +from unittest.mock import patch, PropertyMock + +import pytest +from dagshub_annotation_converter.converters.cvat import export_cvat_video_to_xml_string +from dagshub_annotation_converter.ir.image import IRBBoxImageAnnotation, CoordinateStyle +from dagshub_annotation_converter.ir.video import IRVideoBBoxAnnotation + +from dagshub.data_engine.annotation.importer import AnnotationImporter +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +@pytest.fixture(autouse=True) +def mock_source_prefix(ds): + with patch.object(type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()): + yield + + +# --- import --- + + +def test_import_cvat_video(ds, tmp_path): + xml_file = tmp_path / "annotations.xml" + xml_file.write_bytes(_make_cvat_video_xml()) + + importer = AnnotationImporter(ds, "cvat_video", xml_file, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + anns = list(result.values())[0] + assert len(anns) == 2 + assert all(isinstance(a, IRVideoBBoxAnnotation) for a in anns) + + +# --- _get_all_video_annotations --- + + +def test_get_all_video_filters(ds): + image_ann = IRBBoxImageAnnotation( + filename="test.jpg", categories={"cat": 1.0}, + top=0.1, left=0.1, width=0.2, height=0.2, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.NORMALIZED, + ) + video_ann = _make_video_bbox() + + dp = Datapoint(datasource=ds, path="dp_0", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[image_ann, video_ann] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + result = qr._get_all_video_annotations("ann") + assert len(result) == 1 + assert isinstance(result[0], IRVideoBBoxAnnotation) + + +def test_get_all_video_empty(ds): + dp = Datapoint(datasource=ds, path="dp_0", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + assert qr._get_all_video_annotations("ann") == [] + + +def test_get_all_video_aggregates_across_datapoints(ds): + dps = [] + for i in range(3): + dp = Datapoint(datasource=ds, path=f"dp_{i}", datapoint_id=i, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_video_bbox(frame=i)] + ) + dps.append(dp) + + qr = _make_qr(ds, dps, ann_field="ann") + assert len(qr._get_all_video_annotations("ann")) == 3 + + +# --- export_as_cvat_video --- + + +def test_export_cvat_video_xml(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + + def _mock_download_files(self, target_dir, *args, **kwargs): + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + result = qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + assert result.exists() + assert result == tmp_path / "labels" / "video.mp4.zip" + with zipfile.ZipFile(result, "r") as z: + content = z.read("annotations.xml").decode("utf-8") + assert "") + return output_path + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_cvat_video_to_zip", + _mock_export_cvat_video_to_zip, + ) + + qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + assert captured["video_file"] is not None + assert captured["video_file"].endswith("video.mp4") + + +def test_export_cvat_video_missing_local_file_raises(ds, tmp_path, monkeypatch): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + ann = _make_video_bbox(frame=0, track_id=0) + ann.image_width = 0 + ann.image_height = 0 + ann.filename = "missing.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + qr = _make_qr(ds, [dp], ann_field="ann") + + def _mock_download_files(self, target_dir, *args, **kwargs): + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + + with pytest.raises(FileNotFoundError, match="missing.mp4"): + qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + +# --- helpers --- + + +def _make_video_bbox(frame=0, track_id=0) -> IRVideoBBoxAnnotation: + return IRVideoBBoxAnnotation( + track_id=track_id, frame_number=frame, + left=100.0, top=150.0, width=50.0, height=80.0, + image_width=1920, image_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _make_cvat_video_xml() -> bytes: + anns = [_make_video_bbox(frame=0, track_id=0), _make_video_bbox(frame=5, track_id=0)] + return export_cvat_video_to_xml_string(anns) + + +def _make_video_qr(ds): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, track_id=0), _make_video_bbox(frame=5, track_id=0)] + for ann in anns: + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + return qr, dp + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/annotation_import/test_mot.py b/tests/data_engine/annotation_import/test_mot.py new file mode 100644 index 00000000..9070e676 --- /dev/null +++ b/tests/data_engine/annotation_import/test_mot.py @@ -0,0 +1,383 @@ +import configparser +import datetime +import json +import zipfile +from pathlib import Path, PurePosixPath +from unittest.mock import patch, PropertyMock + +import pytest +from dagshub_annotation_converter.ir.image import CoordinateStyle +from dagshub_annotation_converter.ir.video import IRVideoBBoxAnnotation + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +@pytest.fixture(autouse=True) +def mock_source_prefix(ds): + with patch.object(type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()): + yield + + +# --- _is_video_annotation_dict --- + + +def test_is_video_dict_int_keys(): + assert AnnotationImporter._is_video_annotation_dict({0: [], 1: []}) is True + + +def test_is_video_dict_str_keys(): + assert AnnotationImporter._is_video_annotation_dict({"file.jpg": []}) is False + + +def test_is_video_dict_empty(): + assert AnnotationImporter._is_video_annotation_dict({}) is False + + +def test_is_video_dict_non_dict(): + assert AnnotationImporter._is_video_annotation_dict([]) is False + + +def test_is_video_dict_mixed_first_int(): + assert AnnotationImporter._is_video_annotation_dict({0: [], "a": []}) is True + + +# --- is_video_format --- + + +@pytest.mark.parametrize( + "ann_type, expected", + [ + ("yolo", False), + ("cvat", False), + ("coco", False), + ("mot", True), + ("cvat_video", True), + ], +) +def test_is_video_format(ds, ann_type, expected, tmp_path): + kwargs = {} + if ann_type == "yolo": + kwargs["yolo_type"] = "bbox" + importer = AnnotationImporter(ds, ann_type, tmp_path / "dummy", load_from="disk", **kwargs) + assert importer.is_video_format is expected + + +# --- _flatten_video_annotations --- + + +def test_flatten_merges_frames(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "test_video", load_from="disk") + result = importer._flatten_video_annotations({ + 0: [_make_video_bbox(frame=0)], + 5: [_make_video_bbox(frame=5)], + }) + assert "test_video" in result + assert len(result["test_video"]) == 2 + + +def test_flatten_defaults_to_file_stem(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "my_sequence", load_from="disk") + result = importer._flatten_video_annotations({0: [_make_video_bbox()]}) + assert "my_sequence" in result + + +def test_flatten_video_name_override(ds, tmp_path): + importer = AnnotationImporter( + ds, "mot", tmp_path / "test_video", load_from="disk", video_name="custom.mp4" + ) + result = importer._flatten_video_annotations({0: [_make_video_bbox()]}) + assert "custom.mp4" in result + + +# --- import --- + + +def test_import_mot_from_dir(ds, tmp_path): + mot_dir = tmp_path / "mot_seq" + _create_mot_dir(mot_dir) + + importer = AnnotationImporter(ds, "mot", mot_dir, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + anns = list(result.values())[0] + assert len(anns) == 2 + assert all(isinstance(a, IRVideoBBoxAnnotation) for a in anns) + + +def test_import_mot_from_zip(ds, tmp_path): + mot_dir = tmp_path / "mot_seq" + _create_mot_dir(mot_dir) + zip_path = _zip_mot_dir(tmp_path, mot_dir) + + importer = AnnotationImporter(ds, "mot", zip_path, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + assert len(list(result.values())[0]) == 2 + + +def test_import_mot_from_fs_passes_datasource_path_from_source_prefix(ds, tmp_path, monkeypatch): + captured = {} + + def _mock_load_mot_from_fs(import_dir, image_width=None, image_height=None, video_files=None, datasource_path=""): + captured["import_dir"] = import_dir + captured["image_width"] = image_width + captured["image_height"] = image_height + captured["video_files"] = video_files + captured["datasource_path"] = datasource_path + return {"seq_a": ({0: [_make_video_bbox(frame=0)]}, object())} + + monkeypatch.setattr("dagshub.data_engine.annotation.importer.load_mot_from_fs", _mock_load_mot_from_fs) + + with patch.object( + type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath("data/videos") + ): + importer = AnnotationImporter( + ds, + "mot", + tmp_path, + load_from="disk", + image_width=1280, + image_height=720, + video_files={"seq_a": "dummy.mp4"}, + ) + result = importer.import_annotations() + + assert captured["datasource_path"] == "data/videos" + assert captured["video_files"] == {"seq_a": "dummy.mp4"} + assert captured["image_width"] == 1280 + assert captured["image_height"] == 720 + assert "seq_a" in result + + +def test_import_mot_nonexistent_raises(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "missing", load_from="disk") + with pytest.raises(AnnotationsNotFoundError): + importer.import_annotations() + + +# --- convert_to_ls_tasks --- + + +def test_convert_video_to_ls_tasks(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "video", load_from="disk") + video_anns = {"video.mp4": [_make_video_bbox(frame=0), _make_video_bbox(frame=1)]} + tasks = importer.convert_to_ls_tasks(video_anns) + + assert "video.mp4" in tasks + task_json = json.loads(tasks["video.mp4"]) + assert "annotations" in task_json + + +def test_convert_video_empty_skipped(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "video", load_from="disk") + tasks = importer.convert_to_ls_tasks({"video.mp4": []}) + assert "video.mp4" not in tasks + + +# --- export_as_mot --- + + +def test_export_mot_directory_structure(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + + def _mock_download_files(self, target_dir, *args, **kwargs): + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "gt").mkdir(parents=True, exist_ok=True) + (output_dir / "gt" / "gt.txt").write_text("") + (output_dir / "gt" / "labels.txt").write_text("person\n") + config = configparser.ConfigParser() + config["Sequence"] = {"imWidth": "1920", "imHeight": "1080"} + with open(output_dir / "seqinfo.ini", "w") as f: + config.write(f) + return output_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_to_dir", + _mock_export_mot_to_dir, + ) + result = qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert result.exists() + assert result == tmp_path / "labels" / "video" + assert (result / "gt" / "gt.txt").exists() + assert (result / "gt" / "labels.txt").exists() + assert (result / "seqinfo.ini").exists() + + +def test_export_mot_explicit_dimensions(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + output_dir.mkdir(parents=True, exist_ok=True) + config = configparser.ConfigParser() + config["Sequence"] = { + "imWidth": str(context.image_width), + "imHeight": str(context.image_height), + } + with open(output_dir / "seqinfo.ini", "w") as f: + config.write(f) + (output_dir / "gt").mkdir(parents=True, exist_ok=True) + (output_dir / "gt" / "gt.txt").write_text("") + (output_dir / "gt" / "labels.txt").write_text("person\n") + return output_dir + + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_to_dir", + _mock_export_mot_to_dir, + ) + result = qr.export_as_mot( + download_dir=tmp_path, annotation_field="ann", image_width=1280, image_height=720 + ) + + seqinfo = (result / "seqinfo.ini").read_text() + assert "1280" in seqinfo + assert "720" in seqinfo + + +def test_export_mot_no_annotations_raises(ds, tmp_path): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with pytest.raises(RuntimeError, match="No video annotations"): + qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + +def test_export_mot_multiple_videos(ds, tmp_path, monkeypatch): + dps = [] + for i in range(2): + dp = Datapoint(datasource=ds, path=f"video_{i}.mp4", datapoint_id=i, metadata={}) + ann = _make_video_bbox(frame=i, track_id=i) + ann.filename = dp.path + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + dps.append(dp) + + def _mock_download_files(self, target_dir, *args, **kwargs): + target_dir.mkdir(parents=True, exist_ok=True) + for i in range(2): + (target_dir / f"video_{i}.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_mot_sequences_to_dirs(video_annotations, context, labels_dir, video_files=None): + for i in range(2): + seq_dir = labels_dir / f"video_{i}" + seq_dir.mkdir(parents=True, exist_ok=True) + (seq_dir / "gt").mkdir(parents=True, exist_ok=True) + (seq_dir / "gt" / "gt.txt").write_text("") + (seq_dir / "gt" / "labels.txt").write_text("person\n") + return labels_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_sequences_to_dirs", + _mock_export_mot_sequences_to_dirs, + ) + qr = _make_qr(ds, dps, ann_field="ann") + result = qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert result == tmp_path / "labels" + assert (result / "video_0" / "gt" / "gt.txt").exists() + assert (result / "video_1" / "gt" / "gt.txt").exists() + + +def test_export_mot_passes_video_file_when_dimensions_missing(ds, tmp_path, monkeypatch): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, track_id=1), _make_video_bbox(frame=1, track_id=1)] + for ann in anns: + ann.image_width = 0 + ann.image_height = 0 + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + video_path = target_dir / "video.mp4" + video_path.parent.mkdir(parents=True, exist_ok=True) + video_path.write_bytes(b"video") + return target_dir + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + captured["video_file"] = str(video_file) if video_file is not None else None + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr("dagshub.data_engine.model.query_result.export_mot_to_dir", _mock_export_mot_to_dir) + + qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert captured["video_file"] is not None + assert captured["video_file"].endswith("video.mp4") + + +# --- helpers --- + + +def _make_video_bbox(frame=0, track_id=0) -> IRVideoBBoxAnnotation: + return IRVideoBBoxAnnotation( + track_id=track_id, frame_number=frame, + left=100.0, top=150.0, width=50.0, height=80.0, + image_width=1920, image_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _create_mot_dir(mot_dir: Path): + gt_dir = mot_dir / "gt" + gt_dir.mkdir(parents=True) + (gt_dir / "gt.txt").write_text("1,1,100,150,50,80,1,1,1.0\n2,1,110,160,50,80,1,1,0.9\n") + (gt_dir / "labels.txt").write_text("person\n") + config = configparser.ConfigParser() + config["Sequence"] = { + "name": "test", "frameRate": "30", "seqLength": "100", + "imWidth": "1920", "imHeight": "1080", + } + with open(mot_dir / "seqinfo.ini", "w") as f: + config.write(f) + + +def _zip_mot_dir(tmp_path: Path, mot_dir: Path) -> Path: + zip_path = tmp_path / "mot.zip" + with zipfile.ZipFile(zip_path, "w") as z: + z.write(mot_dir / "gt" / "gt.txt", "gt/gt.txt") + z.write(mot_dir / "gt" / "labels.txt", "gt/labels.txt") + z.write(mot_dir / "seqinfo.ini", "seqinfo.ini") + return zip_path + + +def _make_video_qr(ds): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, track_id=1), _make_video_bbox(frame=1, track_id=1)] + for ann in anns: + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + return qr, dp + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/conftest.py b/tests/data_engine/conftest.py index e8f0c70a..e57d1e83 100644 --- a/tests/data_engine/conftest.py +++ b/tests/data_engine/conftest.py @@ -5,7 +5,7 @@ from dagshub.common.api import UserAPI from dagshub.common.api.responses import UserAPIResponse from dagshub.data_engine import datasources -from dagshub.data_engine.client.models import MetadataSelectFieldSchema, PreprocessingStatus +from dagshub.data_engine.client.models import DatasourceType, MetadataSelectFieldSchema, PreprocessingStatus from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource import DatasetState, Datasource from dagshub.data_engine.model.query_result import QueryResult @@ -26,6 +26,7 @@ def other_ds(mocker, mock_dagshub_auth) -> Datasource: def _create_mock_datasource(mocker, id, name) -> Datasource: ds_state = datasources.DatasourceState(id=id, name=name, repo="kirill/repo") + ds_state.source_type = DatasourceType.REPOSITORY ds_state.path = "repo://kirill/repo/data/" ds_state.preprocessing_status = PreprocessingStatus.READY mocker.patch.object(ds_state, "client") diff --git a/tests/mocks/repo_api.py b/tests/mocks/repo_api.py index d457d161..22b6c94c 100644 --- a/tests/mocks/repo_api.py +++ b/tests/mocks/repo_api.py @@ -113,6 +113,10 @@ def generate_content_api_entry(path, is_dir=False, versioning="dvc") -> ContentA def default_branch(self) -> str: return self._default_branch + @property + def id(self) -> int: + return 1 + def get_connected_storages(self) -> List[StorageAPIEntry]: return self.storages