Skip to content

Commit 5b4b4bf

Browse files
authored
Merge pull request #656 from Labelbox/mmw/create-model-run-with-config
Add model run config on create [AL-3060]
2 parents f1bd083 + 88fd270 commit 5b4b4bf

File tree

5 files changed

+25
-7
lines changed

5 files changed

+25
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,7 @@ a `Label`. Default value is 0.0.
654654

655655
## Version 2.2 (2019-10-18)
656656
Changelog not maintained before version 2.2.
657+
658+
### Changed
659+
* `Model.create_model_run()`
660+
* Add training metadata config as a model run creation param

examples/model_diagnostics/model_diagnostics_guide.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@
350350
"source": [
351351
"lb_model = client.create_model(name=f\"{project.name}-model\",\n",
352352
" ontology_id=project.ontology().uid)\n",
353-
"lb_model_run = lb_model.create_model_run(\"0.0.0\")\n",
353+
"lb_model_run_hyperparameters = {\"batch_size\": 1000}\n",
354+
"lb_model_run = lb_model.create_model_run(\"0.0.0\", lb_model_run_hyperparameters)\n",
354355
"lb_model_run.upsert_labels([label.uid for label in labels])"
355356
]
356357
},

labelbox/schema/model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@ class Model(DbObject):
1818
name = Field.String("name")
1919
model_runs = Relationship.ToMany("ModelRun", False)
2020

21-
def create_model_run(self, name) -> "ModelRun":
21+
def create_model_run(self, name, config=None) -> "ModelRun":
2222
""" Creates a model run belonging to this model.
2323
2424
Args:
2525
name (string): The name for the model run.
26+
config (json): Model run's training metadata config
2627
Returns:
2728
ModelRun, the created model run.
2829
"""
2930
name_param = "name"
31+
config_param = "config"
3032
model_id_param = "modelId"
3133
ModelRun = Entity.ModelRun
32-
query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: ID!) {
33-
createModelRun(data: {name: $%s, modelId: $%s}) {%s}}""" % (
34-
name_param, model_id_param, name_param, model_id_param,
35-
query.results_query_part(ModelRun))
34+
query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: Json, $%s: ID!) {
35+
createModelRun(data: {name: $%s, trainingMetadata: $%s, modelId: $%s}) {%s}}""" % (
36+
name_param, config_param, model_id_param, name_param, config_param,
37+
model_id_param, query.results_query_part(ModelRun))
3638
res = self.client.execute(query_str, {
3739
name_param: name,
40+
config_param: config,
3841
model_id_param: self.uid
3942
})
4043
return ModelRun(self.client, res["createModelRun"])

labelbox/schema/model_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ModelRun(DbObject):
3232
created_at = Field.DateTime("created_at")
3333
created_by_id = Field.String("created_by_id", "createdBy")
3434
model_id = Field.String("model_id")
35+
training_metadata = Field.Json("training_metadata")
3536

3637
class Status(Enum):
3738
EXPORTING_DATA = "EXPORTING_DATA"

tests/integration/annotation_import/test_model_run.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ def test_model_run(client, configured_project_with_label, rand_gen):
1414
model = client.create_model(data["name"], data["ontology_id"])
1515

1616
name = rand_gen(str)
17-
model_run = model.create_model_run(name)
17+
config = {"batch_size": 100, "reruns": None}
18+
model_run = model.create_model_run(name, config)
1819
assert model_run.name == name
20+
assert model_run.training_metadata["batchSize"] == config["batch_size"]
21+
assert model_run.training_metadata["reruns"] == config["reruns"]
1922
assert model_run.model_id == model.uid
2023
assert model_run.created_by_id == client.get_user().uid
2124

@@ -32,6 +35,12 @@ def test_model_run(client, configured_project_with_label, rand_gen):
3235
assert fetch_model_run == model_run
3336

3437

38+
def test_model_run_no_config(rand_gen, model):
39+
name = rand_gen(str)
40+
model_run = model.create_model_run(name)
41+
assert model_run.name == name
42+
43+
3544
def test_model_run_delete(client, model_run):
3645
models_before = list(client.get_models())
3746
model_before = models_before[0]

0 commit comments

Comments
 (0)