|
2 | 2 | from typing import Optional |
3 | 3 | from typing import Union |
4 | 4 |
|
| 5 | +from pydantic import BaseModel |
| 6 | +from pydantic import Field |
| 7 | +from pydantic import StrictFloat |
| 8 | +from pydantic import StrictInt |
| 9 | +from pydantic import StrictStr |
| 10 | +from pydantic import ValidationError |
| 11 | +from pydantic import conlist |
| 12 | +from pydantic.error_wrappers import ErrorWrapper |
| 13 | + |
| 14 | +from superannotate_schemas.schemas.base import AxisPoint |
| 15 | +from superannotate_schemas.schemas.base import BaseAttribute |
| 16 | +from superannotate_schemas.schemas.base import BaseImageMetadata |
5 | 17 | from superannotate_schemas.schemas.base import BaseVectorInstance |
6 | 18 | from superannotate_schemas.schemas.base import BboxPoints |
7 | 19 | from superannotate_schemas.schemas.base import Comment |
8 | | -from superannotate_schemas.schemas.base import BaseImageMetadata |
| 20 | +from superannotate_schemas.schemas.base import INVALID_DICT_MESSAGE |
| 21 | +from superannotate_schemas.schemas.base import NotEmptyStr |
| 22 | +from superannotate_schemas.schemas.base import StrictNumber |
9 | 23 | from superannotate_schemas.schemas.base import Tag |
10 | | -from superannotate_schemas.schemas.base import BaseAttribute |
11 | | -from superannotate_schemas.schemas.base import AxisPoint |
12 | 24 | from superannotate_schemas.schemas.enums import VectorAnnotationTypeEnum |
13 | | -from superannotate_schemas.schemas.base import StrictNumber |
14 | | -from superannotate_schemas.schemas.base import NotEmptyStr |
15 | | - |
16 | | -from pydantic import BaseModel |
17 | | -from pydantic import StrictInt |
18 | | -from pydantic import StrictFloat |
19 | | -from pydantic import conlist |
20 | | -from pydantic import Field |
21 | | -from pydantic import StrictStr |
22 | | -from pydantic import validate_model |
23 | | -from pydantic import ValidationError |
24 | | -from pydantic import validator |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class Attribute(BaseAttribute): |
@@ -53,7 +53,7 @@ class Bbox(VectorInstance): |
53 | 53 | points: BboxPoints |
54 | 54 |
|
55 | 55 |
|
56 | | -class RotatedBoxPoints(VectorInstance): |
| 56 | +class RotatedBoxPoints(BaseModel): |
57 | 57 | x1: StrictNumber |
58 | 58 | y1: StrictNumber |
59 | 59 | x2: StrictNumber |
@@ -118,23 +118,43 @@ class Cuboid(VectorInstance): |
118 | 118 | } |
119 | 119 |
|
120 | 120 |
|
| 121 | +class AnnotationInstance(BaseModel): |
| 122 | + __root__: Union[ |
| 123 | + Template, Cuboid, Point, PolyLine, Polygon, Bbox, Ellipse, RotatedBox |
| 124 | + ] |
| 125 | + |
| 126 | + @classmethod |
| 127 | + def __get_validators__(cls): |
| 128 | + yield cls.return_action |
| 129 | + |
| 130 | + @classmethod |
| 131 | + def return_action(cls, values): |
| 132 | + try: |
| 133 | + try: |
| 134 | + instance_type = values["type"] |
| 135 | + except KeyError: |
| 136 | + raise ValidationError( |
| 137 | + [ErrorWrapper(ValueError("field required"), "type")], cls |
| 138 | + ) |
| 139 | + return ANNOTATION_TYPES[instance_type](**values) |
| 140 | + except KeyError: |
| 141 | + raise ValidationError( |
| 142 | + [ |
| 143 | + ErrorWrapper( |
| 144 | + ValueError( |
| 145 | + f"invalid type, valid types are {', '.join(ANNOTATION_TYPES.keys())}" |
| 146 | + ), |
| 147 | + "type", |
| 148 | + ) |
| 149 | + ], |
| 150 | + cls, |
| 151 | + ) |
| 152 | + except TypeError as e: |
| 153 | + raise TypeError(INVALID_DICT_MESSAGE) from e |
| 154 | + |
| 155 | + |
121 | 156 | class VectorAnnotation(BaseModel): |
122 | 157 | metadata: Metadata |
123 | 158 | comments: Optional[List[Comment]] = Field(list()) |
124 | 159 | tags: Optional[List[Tag]] = Field(list()) |
125 | | - instances: Optional[ |
126 | | - List[ |
127 | | - Union[Template, Cuboid, Point, PolyLine, Polygon, Bbox, Ellipse, RotatedBox] |
128 | | - ] |
129 | | - ] = Field(list()) |
130 | | - |
131 | | - @validator("instances", pre=True, each_item=True) |
132 | | - def check_instances(cls, instance): |
133 | | - # todo add type checking |
134 | | - annotation_type = instance.get("type") |
135 | | - result = validate_model(ANNOTATION_TYPES[annotation_type], instance) |
136 | | - if result[2]: |
137 | | - raise ValidationError( |
138 | | - result[2].raw_errors, model=ANNOTATION_TYPES[annotation_type] |
139 | | - ) |
140 | | - return instance |
| 160 | + instances: Optional[List[AnnotationInstance]] = Field(list()) |
0 commit comments