Skip to content

Commit d48b0b5

Browse files
authored
[SDK-546] Must pass model run id in order for model slices to work correctly (#1387)
2 parents 2f980da + db70a40 commit d48b0b5

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

labelbox/schema/slice.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,10 @@ class ModelSlice(Slice):
239239
@classmethod
240240
def query_str(cls):
241241
query_str = """
242-
query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $from: DataRowIdentifierCursorInput, $first: Int!) {
242+
query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $modelRunId: ID, $from: DataRowIdentifierCursorInput, $first: Int!) {
243243
getDataRowIdentifiersBySavedModelQuery(input: {
244244
savedQueryId: $id,
245+
modelRunId: $modelRunId,
245246
after: $from
246247
first: $first
247248
}) {
@@ -263,17 +264,23 @@ def query_str(cls):
263264
"""
264265
return query_str
265266

266-
def get_data_row_ids(self) -> PaginatedCollection:
267+
def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection:
267268
"""
268269
Fetches all data row ids that match this Slice
269270
271+
Params
272+
model_run_id: str, required, uid or cuid of model run
273+
270274
Returns:
271275
A PaginatedCollection of data row ids
272276
"""
273277
return PaginatedCollection(
274278
client=self.client,
275279
query=ModelSlice.query_str(),
276-
params={'id': str(self.uid)},
280+
params={
281+
'id': str(self.uid),
282+
'modelRunId': model_run_id
283+
},
277284
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
278285
obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id'
279286
),
@@ -282,17 +289,24 @@ def get_data_row_ids(self) -> PaginatedCollection:
282289
'endCursor'
283290
])
284291

285-
def get_data_row_identifiers(self) -> PaginatedCollection:
292+
def get_data_row_identifiers(self,
293+
model_run_id: str) -> PaginatedCollection:
286294
"""
287295
Fetches all data row ids and global keys (where defined) that match this Slice
288296
297+
Params:
298+
model_run_id : str, required, uid or cuid of model run
299+
289300
Returns:
290301
A PaginatedCollection of Slice.DataRowIdAndGlobalKey
291302
"""
292303
return PaginatedCollection(
293304
client=self.client,
294305
query=ModelSlice.query_str(),
295-
params={'id': str(self.uid)},
306+
params={
307+
'id': str(self.uid),
308+
'modelRunId': model_run_id
309+
},
296310
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
297311
obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
298312
data_row_id_and_gk.get('id'),

0 commit comments

Comments
 (0)