diff --git a/label_studio/data_import/api.py b/label_studio/data_import/api.py index 35592589186d..5df936329829 100644 --- a/label_studio/data_import/api.py +++ b/label_studio/data_import/api.py @@ -415,7 +415,48 @@ def create(self, request, *args, **kwargs): # Import -@extend_schema(exclude=True) +@method_decorator( + name='post', + decorator=extend_schema( + tags=['Import'], + summary='Import predictions', + description='Import model predictions for tasks in the specified project.', + parameters=[ + OpenApiParameter( + name='id', + type=OpenApiTypes.INT, + location='path', + description='A unique integer value identifying this project.', + ), + ], + request=PredictionSerializer(many=True), + responses={ + 201: OpenApiResponse( + description='Predictions successfully imported', + response={ + 'title': 'Predictions import response', + 'description': 'Import result', + 'type': 'object', + 'properties': { + 'created': { + 'title': 'created', + 'description': 'Number of predictions created', + 'type': 'integer', + } + }, + }, + ), + 400: OpenApiResponse( + description='Bad Request', + ), + }, + extensions={ + 'x-fern-sdk-group-name': 'projects', + 'x-fern-sdk-method-name': 'import_predictions', + 'x-fern-audiences': ['public'], + }, + ), +) class ImportPredictionsAPI(generics.CreateAPIView): """ API for importing predictions to a project. diff --git a/label_studio/tests/sdk/test_predictions.py b/label_studio/tests/sdk/test_predictions.py index b0acecf36cfa..8e6f76292584 100644 --- a/label_studio/tests/sdk/test_predictions.py +++ b/label_studio/tests/sdk/test_predictions.py @@ -128,3 +128,57 @@ def test_create_predictions_with_import(django_live_url, business_client): ls.projects.update(id=p.id, model_version='3.4.6') assert e.value.status_code == 400 assert e.value.body['validation_errors']['model_version'][0].startswith("Model version doesn't exist") + + +def test_projects_import_predictions(django_live_url, business_client): + """Import multiple predictions via projects.import_predictions + + Purpose: + - Verify that the bulk predictions import endpoint creates predictions for a project + + Setup: + - Create a project with a simple text classification config + - Create one task in the project + + Actions: + - Call ls.projects.import_predictions with three predictions for the same task + + Validations: + - API returns created == 3 + - Listing predictions for the task returns exactly three items with expected model versions + """ + + ls = LabelStudio(base_url=django_live_url, api_key=business_client.api_key) + li = LabelInterface(LABEL_CONFIG_AND_TASKS['label_config']) + project = ls.projects.create(title='Predictions Import Project', label_config=li.config) + + task = ls.tasks.create(project=project.id, data={'my_text': 'Classify this sentence'}) + + model_versions = ['humor__gpt-5-mini', 'humor__gpt-4.1-mini', 'humor__gpt-4o-mini'] + choices = [['Positive'], ['Neutral'], ['Negative']] + + predictions_payload = [] + for mv, ch in zip(model_versions, choices): + predictions_payload.append( + { + 'result': [ + { + 'from_name': 'sentiment_class', + 'to_name': 'message', + 'type': 'choices', + 'value': {'choices': ch}, + } + ], + 'model_version': mv, + 'score': 1, + 'task': task.id, + } + ) + + response = ls.projects.import_predictions(id=project.id, request=predictions_payload) + assert response.created == 3 + + preds = ls.predictions.list(task=task.id) + assert len(preds) == 3 + returned_versions = sorted([p.model_version for p in preds]) + assert returned_versions == sorted(model_versions) diff --git a/poetry.lock b/poetry.lock index 5be45b644de0..857574a0d690 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2136,7 +2136,7 @@ optional = false python-versions = ">=3.9,<4" groups = ["main"] files = [ - {file = "90f4274c21a3ce6304b883b01b14f4aa6af81e41.zip", hash = "sha256:76e026573e83d05ee3b8765805340a58b14e62483a0349448feb72f041ee1ca8"}, + {file = "83f2c2ee4d7daabae5f5d182e65502e754390ea0.zip", hash = "sha256:fa7dc3d57d5e1b2ee444f569741c210448234a5f48912246b43bb12d6c7a9262"}, ] [package.dependencies] @@ -2164,7 +2164,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/90f4274c21a3ce6304b883b01b14f4aa6af81e41.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/83f2c2ee4d7daabae5f5d182e65502e754390ea0.zip" [[package]] name = "launchdarkly-server-sdk" @@ -5109,4 +5109,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "acbad97fcd06d243f46b5a0b8bd8f9392400c9657f97c24979b18c1a6c3d48da" +content-hash = "2e6ba98bdd4b6ef95b3cd0b103cac0327d455270725ade72c4abbd1e014459da" diff --git a/pyproject.toml b/pyproject.toml index ac182c316ead..431a9da04219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/90f4274c21a3ce6304b883b01b14f4aa6af81e41.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/83f2c2ee4d7daabae5f5d182e65502e754390ea0.zip", ## HumanSignal repo dependencies :end ]