1010from .constants import (
1111 BACKEND_REFERENCE_ID_KEY ,
1212 CAMERA_PARAMS_KEY ,
13+ EMBEDDING_INFO_KEY ,
14+ EMBEDDING_VECTOR_KEY ,
1315 IMAGE_URL_KEY ,
16+ INDEX_ID_KEY ,
1417 METADATA_KEY ,
1518 ORIGINAL_IMAGE_URL_KEY ,
1619 POINTCLOUD_URL_KEY ,
@@ -26,6 +29,18 @@ class DatasetItemType(Enum):
2629 POINTCLOUD = "pointcloud"
2730
2831
32+ @dataclass
33+ class DatasetItemEmbeddingInfo :
34+ index_id : str
35+ embedding_vector : list
36+
37+ def to_payload (self ) -> dict :
38+ return {
39+ INDEX_ID_KEY : self .index_id ,
40+ EMBEDDING_VECTOR_KEY : self .embedding_vector ,
41+ }
42+
43+
2944@dataclass # pylint: disable=R0902
3045class DatasetItem : # pylint: disable=R0902
3146 """A dataset item is an image or pointcloud that has associated metadata.
@@ -113,16 +128,23 @@ class DatasetItem: # pylint: disable=R0902
113128 metadata : Optional [dict ] = None
114129 pointcloud_location : Optional [str ] = None
115130 upload_to_scale : Optional [bool ] = True
131+ embedding_info : Optional [DatasetItemEmbeddingInfo ] = None
116132
117133 def __post_init__ (self ):
118134 assert self .reference_id != "DUMMY_VALUE" , "reference_id is required."
119135 assert bool (self .image_location ) != bool (
120136 self .pointcloud_location
121137 ), "Must specify exactly one of the image_location or pointcloud_location parameters"
138+ if self .pointcloud_location and self .embedding_info :
139+ raise AssertionError (
140+ "Cannot upload embedding vector if pointcloud_location is set"
141+ )
142+
122143 if (self .pointcloud_location ) and not self .upload_to_scale :
123144 raise NotImplementedError (
124145 "Skipping upload to Scale is not currently implemented for pointclouds."
125146 )
147+
126148 self .local = (
127149 is_local_path (self .image_location ) if self .image_location else None
128150 )
@@ -179,6 +201,9 @@ def to_payload(self, is_scene=False) -> dict:
179201
180202 payload [REFERENCE_ID_KEY ] = self .reference_id
181203
204+ if self .embedding_info :
205+ payload [EMBEDDING_INFO_KEY ] = self .embedding_info .to_payload ()
206+
182207 if is_scene :
183208 if self .image_location :
184209 payload [URL_KEY ] = self .image_location
0 commit comments