|
2 | 2 |
|
3 | 3 | import requests |
4 | 4 |
|
5 | | -from .constants import METADATA_KEY, NAME_KEY, REFERENCE_ID_KEY |
| 5 | +from .constants import METADATA_KEY, MODEL_TAGS_KEY, NAME_KEY, REFERENCE_ID_KEY |
6 | 6 | from .dataset import Dataset |
7 | 7 | from .job import AsyncJob |
8 | 8 | from .model_run import ModelRun |
@@ -93,13 +93,21 @@ class Model: |
93 | 93 | """ |
94 | 94 |
|
95 | 95 | def __init__( |
96 | | - self, model_id, name, reference_id, metadata, client, bundle_name=None |
| 96 | + self, |
| 97 | + model_id, |
| 98 | + name, |
| 99 | + reference_id, |
| 100 | + metadata, |
| 101 | + client, |
| 102 | + bundle_name=None, |
| 103 | + tags: List[str] = None, |
97 | 104 | ): |
98 | 105 | self.id = model_id |
99 | 106 | self.name = name |
100 | 107 | self.reference_id = reference_id |
101 | 108 | self.metadata = metadata |
102 | 109 | self.bundle_name = bundle_name |
| 110 | + self.tags = tags if tags else [] |
103 | 111 | self._client = client |
104 | 112 |
|
105 | 113 | def __repr__(self): |
@@ -213,3 +221,49 @@ def run(self, dataset_id: str, slice_id: Optional[str]) -> str: |
213 | 221 | ) |
214 | 222 |
|
215 | 223 | return response |
| 224 | + |
| 225 | + def add_tags(self, tags: List[str]): |
| 226 | + """Tag the model with custom tag names. :: |
| 227 | +
|
| 228 | + import nucleus |
| 229 | + client = nucleus.NucleusClient("YOUR_SCALE_API_KEY") |
| 230 | + model = client.list_models()[0] |
| 231 | +
|
| 232 | + model.add_tags(["tag_A", "tag_B"]) |
| 233 | +
|
| 234 | + Args: |
| 235 | + tags: list of tag names |
| 236 | + """ |
| 237 | + response = self._client.make_request( |
| 238 | + {MODEL_TAGS_KEY: tags}, |
| 239 | + f"model/{self.id}/tag", |
| 240 | + requests_command=requests.post, |
| 241 | + ) |
| 242 | + |
| 243 | + if response.get("msg", False): |
| 244 | + self.tags.extend(tags) |
| 245 | + |
| 246 | + return response |
| 247 | + |
| 248 | + def remove_tags(self, tags: List[str]): |
| 249 | + """Remove tag(s) from the model. :: |
| 250 | +
|
| 251 | + import nucleus |
| 252 | + client = nucleus.NucleusClient("YOUR_SCALE_API_KEY") |
| 253 | + model = client.list_models()[0] |
| 254 | +
|
| 255 | + model.remove_tags(["tag_x"]) |
| 256 | +
|
| 257 | + Args: |
| 258 | + tags: list of tag names to remove |
| 259 | + """ |
| 260 | + response = self._client.make_request( |
| 261 | + {MODEL_TAGS_KEY: tags}, |
| 262 | + f"model/{self.id}/tag", |
| 263 | + requests_command=requests.delete, |
| 264 | + ) |
| 265 | + |
| 266 | + if response.get("msg", False): |
| 267 | + self.tags = list(filter(lambda t: t not in tags, self.tags)) |
| 268 | + |
| 269 | + return response |
0 commit comments