
Commit 6f5b720

Restructured tests and schemas
1 parent a0f8911 commit 6f5b720

File tree

15 files changed (+622, -38 lines)


src/superannotate_databricks_connector/schemas/comment.py

Lines changed: 49 additions & 5 deletions
@@ -5,12 +5,15 @@
     FloatType,
     BooleanType,
     MapType,
-    ArrayType
+    ArrayType,
+    IntegerType
 )
 
+from .shapes import get_bbox_schema
 
-def get_comment_schema():
-    comment_schema = StructType([
+
+def get_vector_comment_schema():
+    return StructType([
         StructField("correspondence",
                     ArrayType(MapType(
                         StringType(),
@@ -23,12 +26,53 @@ def get_comment_schema():
         StructField("createdBy", MapType(
             StringType(),
             StringType()),
+            True),
+        StructField("creationType", StringType(), True),
+        StructField("updatedAt", StringType(), True),
+        StructField("updatedBy", MapType(
+            StringType(),
+            StringType()),
+            True)
+    ])
+
+
+def get_video_timestamp_schema():
+    return StructType([
+        StructField("timestamp", IntegerType(), True),
+        StructField("points", get_bbox_schema(), True)
+    ])
+
+
+def get_video_comment_parameter_schema():
+    return StructType([
+        StructField("start", IntegerType(), True),
+        StructField("end", IntegerType(), True),
+        StructField("timestamps", ArrayType(
+            get_video_timestamp_schema()), True)
+    ])
+
+
+def get_video_comment_schema():
+    return StructType([
+        StructField("correspondence",
+                    ArrayType(MapType(
+                        StringType(),
+                        StringType())),
             True),
+        StructField("start", IntegerType(), True),
+        StructField("end", IntegerType(), True),
+        StructField("createdAt", StringType(), True),
+        StructField("createdBy", MapType(
+            StringType(),
+            StringType()),
+            True),
         StructField("creationType", StringType(), True),
         StructField("updatedAt", StringType(), True),
         StructField("updatedBy", MapType(
             StringType(),
             StringType()),
-            True)
+            True),
+        StructField("parameters",
+                    ArrayType(get_video_comment_parameter_schema()), True)
+
     ])
-    return comment_schema
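
As a usage sketch (not part of this commit): a single comment dict shaped like get_video_comment_schema() can be handed straight to createDataFrame. The sample values and the `spark` session (e.g. a Databricks notebook with the package installed) are assumptions.

from superannotate_databricks_connector.schemas.comment import get_video_comment_schema

sample_comment = {
    "correspondence": [{"text": "Check this frame"}],
    "start": 1000,
    "end": 2000,
    "createdAt": "2023-05-01T10:00:00.000Z",
    "createdBy": {"email": "annotator@example.com", "role": "Annotator"},
    "creationType": "Manual",
    "updatedAt": "2023-05-01T10:05:00.000Z",
    "updatedBy": {"email": "qa@example.com", "role": "QA"},
    "parameters": [{
        "start": 1000,
        "end": 2000,
        "timestamps": [{"timestamp": 1000,
                        "points": {"x1": 10.0, "y1": 20.0,
                                   "x2": 110.0, "y2": 220.0}}],
    }],
}

# One-row DataFrame validated against the nested schema.
df = spark.createDataFrame([sample_comment], schema=get_video_comment_schema())
df.printSchema()
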
src/superannotate_databricks_connector/schemas/shapes.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    FloatType,
+    ArrayType
+)
+
+
+def get_bbox_schema():
+    """
+    Defines the schema of a bounding box.
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of a bbox
+    """
+    return StructType([
+        StructField("x1", FloatType(), True),
+        StructField("y1", FloatType(), True),
+        StructField("x2", FloatType(), True),
+        StructField("y2", FloatType(), True)
+    ])
+
+
+def get_rbbox_schema():
+    """
+    Defines the schema of a rotated bounding box,
+    which contains one point for each corner.
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of a rotated bbox
+    """
+    return StructType([
+        StructField("x1", FloatType(), True),
+        StructField("y1", FloatType(), True),
+        StructField("x2", FloatType(), True),
+        StructField("y2", FloatType(), True),
+        StructField("x3", FloatType(), True),
+        StructField("y3", FloatType(), True),
+        StructField("x4", FloatType(), True),
+        StructField("y4", FloatType(), True)
+    ])
+
+
+def get_point_schema():
+    """
+    Defines the schema of a point.
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of a point
+    """
+    return StructType([
+        StructField("x", FloatType(), True),
+        StructField("y", FloatType(), True)
+    ])
+
+
+def get_cuboid_schema():
+    """
+    Defines the schema of a cuboid (3D bounding box).
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of a cuboid
+    """
+    return StructType([
+        StructField("f1", get_point_schema(), True),
+        StructField("f2", get_point_schema(), True),
+        StructField("r1", get_point_schema(), True),
+        StructField("r2", get_point_schema(), True)
+    ])
+
+
+def get_ellipse_schema():
+    """
+    Defines the schema of an ellipse.
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of an ellipse
+    """
+    return StructType([
+        StructField("cx", FloatType(), True),
+        StructField("cy", FloatType(), True),
+        StructField("rx", FloatType(), True),
+        StructField("ry", FloatType(), True),
+        StructField("angle", FloatType(), True)
+    ])
+
+
+def get_polygon_schema():
+    """
+    Defines the schema of a polygon. It contains a shell as well
+    as excluded regions (holes).
+
+    Args:
+        None
+
+    Returns:
+        StructType: Schema of a polygon with holes
+    """
+    return StructType([
+        StructField("points", ArrayType(FloatType()), True),
+        StructField("exclude", ArrayType(ArrayType(FloatType())), True)
+    ])
+
+
+def get_polyline_schema():
+    """
+    Defines the schema of a polyline:
+    a simple array of floats.
+
+    Args:
+        None
+
+    Returns:
+        ArrayType: Schema of a polyline
+    """
+    return ArrayType(FloatType())
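
A short sketch (not part of this commit) of how these shape schemas compose: get_cuboid_schema() nests get_point_schema() four times, so a cuboid row is just four point structs. The sample row and the `spark` session are assumptions.

from superannotate_databricks_connector.schemas.shapes import get_cuboid_schema

cuboid_row = {
    "f1": {"x": 0.0, "y": 0.0},    # each corner matches get_point_schema()
    "f2": {"x": 10.0, "y": 10.0},
    "r1": {"x": 2.0, "y": 2.0},
    "r2": {"x": 12.0, "y": 12.0},
}

df = spark.createDataFrame([cuboid_row], schema=get_cuboid_schema())
df.show(truncate=False)
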
src/superannotate_databricks_connector/schemas/tag.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    MapType,
+    ArrayType,
+)
+
+
+def get_tag_schema():
+    schema = StructType([
+        StructField("instance_type", StringType(), True),
+        StructField("classId", IntegerType(), True),
+        StructField("probability", IntegerType(), True),
+        StructField("attributes", ArrayType(MapType(StringType(),
+                                                    StringType())),
+                    True),
+        StructField("createdAt", StringType(), True),
+        StructField("createdBy", MapType(StringType(), StringType()), True),
+        StructField("creationType", StringType(), True),
+        StructField("updatedAt", StringType(), True),
+        StructField("updatedBy", MapType(StringType(), StringType()), True),
+        StructField("className", StringType(), True)])
+    return schema

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    ArrayType,
+    MapType
+)
+
+from .shapes import (
+    get_point_schema,
+    get_bbox_schema,
+    get_polygon_schema,
+    get_polyline_schema,
+)
+from .comment import get_video_comment_schema
+from .tag import get_tag_schema
+
+
+def get_instance_metadata_schema():
+    return StructType([
+        StructField("type", StringType(), True),
+        StructField("className", StringType(), True),
+        StructField("start", IntegerType(), True),
+        StructField("end", IntegerType(), True),
+        StructField("creationType", StringType(), True)
+    ])
+
+
+def get_timestamp_schema():
+    return StructType([
+        StructField("timestamp", IntegerType(), True),
+        StructField("bbox", get_bbox_schema(), True),
+        StructField("polygon", get_polygon_schema(), True),
+        StructField("polyline", get_polyline_schema(), True),
+        StructField("point", get_point_schema(), True),
+        StructField("attributes", ArrayType(MapType(StringType(),
+                                                    StringType())),
+                    True)
+    ])
+
+
+def get_instance_parameter_schema():
+    return StructType([
+        StructField("start", IntegerType(), True),
+        StructField("end", IntegerType(), True),
+        StructField("timestamp", get_timestamp_schema(), True)
+    ])
+
+
+def get_instance_schema():
+    return StructType([
+        StructField("meta", get_instance_metadata_schema(), True),
+        StructField("parameters", get_instance_parameter_schema(), True)
+    ])
+
+
+def get_video_schema():
+    return StructType([
+        StructField("video_height", IntegerType(), True),
+        StructField("video_width", IntegerType(), True),
+        StructField("video_name", StringType(), True),
+        StructField("url", StringType(), True),
+        StructField("projectId", IntegerType(), True),
+        StructField("duration", IntegerType(), True),
+        StructField("status", StringType(), True),
+        StructField("annotatorEmail", StringType(), True),
+        StructField("qaEmail", StringType(), True),
+        StructField("instances", ArrayType(get_instance_schema()),
+                    True),
+        StructField("comments", ArrayType(get_video_comment_schema()), True),
+        StructField("tags", ArrayType(get_tag_schema()), True)
+    ])
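
A minimal usage sketch for the video schema (not part of this commit), mirroring how vector.py hands rows plus a schema to spark.createDataFrame. The module path schemas.video is an assumption based on the relative imports above; `spark` is assumed to exist, as on Databricks.

# Hypothetical import path; the diff shows the module's contents but not its file name.
from superannotate_databricks_connector.schemas.video import get_video_schema

schema = get_video_schema()
print(schema.simpleString())                 # compact view of the nested struct

# Empty DataFrame with the right columns, e.g. as a target table template.
empty_videos = spark.createDataFrame([], schema=schema)
empty_videos.printSchema()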

src/superannotate_databricks_connector/vector.py

Lines changed: 21 additions & 15 deletions
@@ -24,25 +24,29 @@ def process_vector_object(instance, custom_id_map=None):
         'classId': instance["classId"] if custom_id_map is None
         else custom_id_map.get(instance["className"]),
         'probability': instance.get('probability'),
-        'bbox_points': {k: float(v) for k, v in instance['points'].items()}
+        'bbox': {k: float(v) for k, v in instance['points'].items()}
         if instance["type"] == "bbox" else None,
-        'polygon_points': [float(p) for p in instance['points']]
+        'rbbox': {k: float(v) for k, v in instance['points'].items()}
+        if instance["type"] == "rbbox" else None,
+        'polygon': {"points": [float(p) for p in instance['points']]
+                    if instance["type"] == "polygon" else None,
+                    'exclude': instance.get("exclude")}
         if instance["type"] == "polygon" else None,
-        'polygon_exclude': instance["exclude"]
-        if instance["type"] == "polygon" else None,
-        'point_points': {"x": float(instance["x"]),
-                         "y": float(instance["y"])
-                         } if instance["type"] == "point" else None,
-        'ellipse_points': {"cx": float(instance["cx"]),
-                           "cy": float(instance["cy"]),
-                           "rx": float(instance["rx"]),
-                           "ry": float(instance["ry"]),
-                           "angle": float(instance["angle"])}
+        'point': {"x": float(instance.get("x")),
+                  "y": float(instance.get("y"))
+                  } if instance["type"] == "point" else None,
+        'ellipse': {"cx": float(instance["cx"]),
+                    "cy": float(instance["cy"]),
+                    "rx": float(instance["rx"]),
+                    "ry": float(instance["ry"]),
+                    "angle": float(instance["angle"])}
         if instance["type"] == "ellipse" else None,
-        'cuboid_points': {outer_k: {inner_k: float(inner_v)
-                                    for inner_k, inner_v in outer_v.items()}
-                          for outer_k, outer_v in instance['points'].items()}
+        'cuboid': {outer_k: {inner_k: float(inner_v)
+                             for inner_k, inner_v in outer_v.items()}
+                   for outer_k, outer_v in instance['points'].items()}
         if instance["type"] == "cuboid" else None,
+        "polyline": [float(p) for p in instance['points']]
+        if instance["type"] == "polyline" else None,
         'groupId': instance.get('groupId'),
         'locked': instance.get('locked'),
         'attributes': instance['attributes'],
@@ -150,7 +154,9 @@ def get_vector_dataframe(annotations, spark, custom_id_map=None):
             "comments": [process_comment(comment)
                          for comment in item["comments"]]
         }
+
         rows.append(flattened_item)
     schema = get_vector_schema()
+
     spark_df = spark.createDataFrame(rows, schema=schema)
     return spark_df
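
The renamed keys in process_vector_object follow one pattern: every flattened row carries a nullable entry per geometry type, and only the entry matching instance["type"] is populated. A standalone sketch of that pattern (not the connector's code; the key names mirror the diff above):

def flatten_shape(instance):
    """Fill one key per geometry type; non-matching keys stay None."""
    t = instance["type"]
    return {
        "bbox": {k: float(v) for k, v in instance["points"].items()}
        if t == "bbox" else None,
        "point": {"x": float(instance["x"]), "y": float(instance["y"])}
        if t == "point" else None,
        "polyline": [float(p) for p in instance["points"]]
        if t == "polyline" else None,
    }

print(flatten_shape({"type": "point", "x": 3, "y": 4}))
# {'bbox': None, 'point': {'x': 3.0, 'y': 4.0}, 'polyline': None}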
