Skip to content

Commit 56eba72

Browse files
committed
Bugfixes. Missing instances and tags
1 parent f49cc7f commit 56eba72

File tree

5 files changed

+34
-10
lines changed

5 files changed

+34
-10
lines changed

src/superannotate_databricks_connector/schemas/text_schema.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def get_text_schema():
4242
schema = StructType([
4343
StructField("name", StringType(), True),
4444
StructField("url", StringType(), True),
45-
StructField("contentLength", IntegerType(), True),
4645
StructField("projectId", IntegerType(), True),
4746
StructField("status", StringType(), True),
4847
StructField("annotatorEmail", StringType(), True),

src/superannotate_databricks_connector/schemas/vector_schema.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ def get_vector_instance_schema():
5959
return instance_schema
6060

6161

62+
def get_vector_tag_schema():
    """Build the Spark schema for one entry of a vector annotation's tags.

    Returns:
        StructType: a struct whose fields are all nullable, in this order:
        ``instance_type``, ``classId``, ``probability``, ``attributes``,
        ``createdAt``, ``createdBy``, ``creationType``, ``updatedAt``,
        ``updatedBy`` and ``className``.
    """

    def text(name):
        # Nullable string column.
        return StructField(name, StringType(), True)

    def whole_number(name):
        # Nullable integer column.
        return StructField(name, IntegerType(), True)

    def text_map(name):
        # Nullable string-to-string map column.
        return StructField(name, MapType(StringType(), StringType()), True)

    # Attributes are a list of string-to-string maps.
    attributes = StructField(
        "attributes",
        ArrayType(MapType(StringType(), StringType())),
        True,
    )

    columns = [
        text("instance_type"),
        whole_number("classId"),
        whole_number("probability"),
        attributes,
        text("createdAt"),
        text_map("createdBy"),
        text("creationType"),
        text("updatedAt"),
        text_map("updatedBy"),
        text("className"),
    ]
    return StructType(columns)
77+
78+
6279
def get_vector_schema():
6380
schema = StructType([
6481
StructField("image_height", IntegerType(), True),
@@ -73,6 +90,7 @@ def get_vector_schema():
7390
StructField("instances", ArrayType(get_vector_instance_schema()),
7491
True),
7592
StructField("bounding_boxes", ArrayType(IntegerType()), True),
76-
StructField("comments", ArrayType(get_comment_schema()), True)
93+
StructField("comments", ArrayType(get_comment_schema()), True),
94+
StructField("tags", ArrayType(get_vector_tag_schema()), True)
7795
])
7896
return schema

src/superannotate_databricks_connector/text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from datetime import datetime
2-
from superannotate_databricks_connector.schemas.text_schema import get_text_schema
2+
from superannotate_databricks_connector.schemas.text_schema import (
3+
get_text_schema
4+
)
35

46

57
def convert_dates(instance):
@@ -40,7 +42,6 @@ def get_text_dataframe(annotations, spark):
4042
flattened_item = {
4143
"name": item["metadata"]["name"],
4244
"url": item["metadata"]["url"],
43-
"contentLength": item["metadata"]["contentLength"],
4445
"projectId": item["metadata"]["projectId"],
4546
"status": item["metadata"]["status"],
4647
"annotatorEmail": item["metadata"]["annotatorEmail"],

src/superannotate_databricks_connector/vector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from superannotate_databricks_connector.schemas.vector_schema import get_vector_schema
1+
from superannotate_databricks_connector.schemas.vector_schema import (
2+
get_vector_schema
3+
)
24

35

46
def process_comment(comment):

tests/test_vector.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
class TestVectorInstances(unittest.TestCase):
1818
def __init__(self, *args):
1919
super().__init__(*args)
20-
with open(os.path.join(DATA_SET_PATH, "vector/example_annotation.json"), "r") as f:
20+
with open(os.path.join(DATA_SET_PATH,
21+
"vector/example_annotation.json"), "r") as f:
2122
data = json.load(f)
2223

2324
target_data = []
24-
with open(os.path.join(DATA_SET_PATH, 'vector/expected_instances.json'),"r") as f:
25+
with open(os.path.join(DATA_SET_PATH,
26+
'vector/expected_instances.json'), "r") as f:
2527
for line in f:
2628
target_data.append(json.loads(line))
2729

@@ -96,20 +98,22 @@ def test_get_boxes(self):
9698
"y1": 2.1,
9799
"y2": 18.9
98100
},
99-
"classId": 10229}]
101+
"classId": 10229}]
100102
target = [2, 1, 13, 22, 10228, 3, 2, 4, 19, 10229]
101103
self.assertEqual(get_boxes(instances), target)
102104

103105

104106
class TestVectorDataFrame(unittest.TestCase):
105107
def test_vector_dataframe(self):
106108
spark = SparkSession.builder.master("local").getOrCreate()
107-
with open(os.path.join(DATA_SET_PATH, "vector/example_annotation.json"),"r") as f:
109+
with open(os.path.join(DATA_SET_PATH,
110+
"vector/example_annotation.json"), "r") as f:
108111
data = json.load(f)
109112

110113
actual_df = get_vector_dataframe([data], spark)
111114

112-
expected_df = spark.read.parquet(os.path.join(DATA_SET_PATH, "vector/expected_df.parquet"))
115+
expected_df = spark.read.parquet(os.path.join(
116+
DATA_SET_PATH, "vector/expected_df.parquet"))
113117
self.assertEqual(sorted(actual_df.collect()),
114118
sorted(expected_df.collect()))
115119

0 commit comments

Comments
 (0)