|
3 | 3 | import pytest |
4 | 4 | import requests |
5 | 5 |
|
6 | | -from nucleus import BoxAnnotation, Dataset, NucleusClient, Slice |
7 | | -from nucleus.constants import ANNOTATIONS_KEY, BOX_TYPE, ITEM_KEY |
| 6 | +from nucleus import BoxAnnotation, BoxPrediction, Dataset, NucleusClient, Slice |
| 7 | +from nucleus.constants import ( |
| 8 | + ANNOTATIONS_KEY, |
| 9 | + BOX_TYPE, |
| 10 | + ITEM_KEY, |
| 11 | + PREDICTIONS_KEY, |
| 12 | +) |
8 | 13 | from nucleus.job import AsyncJob |
9 | 14 |
|
10 | 15 | from .helpers import ( |
11 | 16 | TEST_BOX_ANNOTATIONS, |
| 17 | + TEST_BOX_PREDICTIONS, |
12 | 18 | TEST_PROJECT_ID, |
13 | 19 | TEST_SLICE_NAME, |
14 | 20 | get_uuid, |
15 | 21 | ) |
16 | 22 |
|
17 | 23 |
|
| 24 | +@pytest.fixture() |
| 25 | +def slc(CLIENT, dataset): |
| 26 | + slice_ref_ids = [item.reference_id for item in dataset.items[:1]] |
| 27 | + # Slice creation |
| 28 | + slc = dataset.create_slice( |
| 29 | + name=TEST_SLICE_NAME, |
| 30 | + reference_ids=slice_ref_ids, |
| 31 | + ) |
| 32 | + |
| 33 | + yield slc |
| 34 | + |
| 35 | + CLIENT.delete_slice(slc.id) |
| 36 | + |
| 37 | + |
18 | 38 | def test_reprs(): |
19 | 39 | # Have to define here in order to have access to all relevant objects |
20 | 40 | def test_repr(test_object: any): |
@@ -89,6 +109,40 @@ def get_expected_item(reference_id): |
89 | 109 | ] == get_expected_box_annotation(reference_id) |
90 | 110 |
|
91 | 111 |
|
| 112 | +def test_slice_create_and_prediction_export(dataset, slc, model): |
| 113 | + # Dataset upload |
| 114 | + ds_items = dataset.items |
| 115 | + |
| 116 | + predictions = [ |
| 117 | + BoxPrediction(**pred_raw) for pred_raw in TEST_BOX_PREDICTIONS |
| 118 | + ] |
| 119 | + response = dataset.upload_predictions(model, predictions) |
| 120 | + |
| 121 | + assert response |
| 122 | + |
| 123 | + slice_reference_ids = [item.reference_id for item in slc.items] |
| 124 | + |
| 125 | + def get_expected_box_prediction(reference_id): |
| 126 | + for prediction in predictions: |
| 127 | + if prediction.reference_id == reference_id: |
| 128 | + return prediction |
| 129 | + |
| 130 | + def get_expected_item(reference_id): |
| 131 | + if reference_id not in slice_reference_ids: |
| 132 | + raise ValueError("Got results outside the slice") |
| 133 | + for item in ds_items: |
| 134 | + if item.reference_id == reference_id: |
| 135 | + return item |
| 136 | + |
| 137 | + exported = slc.export_predictions(model) |
| 138 | + for row in exported: |
| 139 | + reference_id = row[ITEM_KEY].reference_id |
| 140 | + assert row[ITEM_KEY] == get_expected_item(reference_id) |
| 141 | + assert row[PREDICTIONS_KEY][BOX_TYPE][ |
| 142 | + 0 |
| 143 | + ] == get_expected_box_prediction(reference_id) |
| 144 | + |
| 145 | + |
92 | 146 | def test_slice_append(dataset): |
93 | 147 | ds_items = dataset.items |
94 | 148 |
|
|
0 commit comments