33import requests
44
55from .async_job import AsyncJob
6- from .constants import METADATA_KEY , MODEL_TAGS_KEY , NAME_KEY , REFERENCE_ID_KEY
6+ from .constants import (
7+ METADATA_KEY ,
8+ MODEL_TAGS_KEY ,
9+ MODEL_TRAINED_SLICE_IDS_KEY ,
10+ NAME_KEY ,
11+ REFERENCE_ID_KEY ,
12+ )
713from .dataset import Dataset
814from .model_run import ModelRun
915from .prediction import (
@@ -101,6 +107,7 @@ def __init__(
101107 client ,
102108 bundle_name = None ,
103109 tags = None ,
110+ trained_slice_ids = None ,
104111 ):
105112 self .id = model_id
106113 self .name = name
@@ -109,9 +116,10 @@ def __init__(
109116 self .bundle_name = bundle_name
110117 self .tags = tags if tags else []
111118 self ._client = client
119+ self .trained_slice_ids = trained_slice_ids if trained_slice_ids else []
112120
113121 def __repr__ (self ):
114- return f"Model(model_id='{ self .id } ', name='{ self .name } ', reference_id='{ self .reference_id } ', metadata={ self .metadata } , bundle_name={ self .bundle_name } , tags={ self .tags } , client={ self ._client } )"
122+ return f"Model(model_id='{ self .id } ', name='{ self .name } ', reference_id='{ self .reference_id } ', metadata={ self .metadata } , bundle_name={ self .bundle_name } , tags={ self .tags } , client={ self ._client } , trained_slice_ids= { self . trained_slice_ids } )"
115123
116124 def __eq__ (self , other ):
117125 return (
@@ -120,6 +128,7 @@ def __eq__(self, other):
120128 and (self .metadata == other .metadata )
121129 and (self ._client == other ._client )
122130 and (self .bundle_name == other .bundle_name )
131+ and (self .trained_slice_ids == other .trained_slice_ids )
123132 )
124133
125134 def __hash__ (self ):
@@ -134,6 +143,8 @@ def from_json(cls, payload: dict, client):
134143 reference_id = payload ["ref_id" ],
135144 metadata = payload ["metadata" ] or None ,
136145 client = client ,
146+ tags = payload .get (MODEL_TAGS_KEY , None ),
147+ trained_slice_ids = payload .get (MODEL_TRAINED_SLICE_IDS_KEY , None ),
137148 )
138149
139150 def create_run (
@@ -242,7 +253,9 @@ def add_tags(self, tags: List[str]):
242253 )
243254
244255 if response .ok :
245- self .tags .extend (tags )
256+ for tag in tags :
257+ if tag not in self .tags :
258+ self .tags .append (tag )
246259
247260 return response .json ()
248261
@@ -269,3 +282,55 @@ def remove_tags(self, tags: List[str]):
269282 self .tags = list (filter (lambda t : t not in tags , self .tags ))
270283
271284 return response .json ()
285+
286+ def add_trained_slice_ids (self , slice_ids : List [str ]):
287+ """Add trained slice id(s) to the model. ::
288+
289+ import nucleus
290+ client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
291+ model = client.list_models()[0]
292+
293+ model.add_trained_slice_ids(["slc_...", "slc_..."])
294+
295+ Args:
296+ slice_ids: list of trained slice ids
297+ """
298+ response : requests .Response = self ._client .make_request (
299+ {MODEL_TRAINED_SLICE_IDS_KEY : slice_ids },
300+ f"model/{ self .id } /trainedSliceId" ,
301+ requests_command = requests .post ,
302+ return_raw_response = True ,
303+ )
304+
305+ if response .ok :
306+ for slice_id in slice_ids :
307+ if slice_id not in self .trained_slice_ids :
308+ self .trained_slice_ids .append (slice_id )
309+
310+ return response .json ()
311+
312+ def remove_trained_slice_ids (self , slide_ids : List [str ]):
313+ """Remove trained slice id(s) from the model. ::
314+
315+ import nucleus
316+ client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
317+ model = client.list_models()[0]
318+
319+ model.remove_trained_slice_ids(["slc_...", "slc_..."])
320+
321+ Args:
322+ slice_ids: list of trained slice ids to remove
323+ """
324+ response : requests .Response = self ._client .make_request (
325+ {MODEL_TRAINED_SLICE_IDS_KEY : slide_ids },
326+ f"model/{ self .id } /trainedSliceId" ,
327+ requests_command = requests .delete ,
328+ return_raw_response = True ,
329+ )
330+
331+ if response .ok :
332+ self .trained_slice_ids = list (
333+ filter (lambda t : t not in slide_ids , self .trained_slice_ids )
334+ )
335+
336+ return response .json ()
0 commit comments