diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index 35d719b0d..887ad7a2c 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -20,13 +20,29 @@ def list( False, "--unregistered", help="Get unregistered benchmarks" ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), + name: str = typer.Option(None, "--name", help="Filter by name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"), + is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"), + data_prep: int = typer.Option(None, "-d", "--data-preparation-mlcube", help="Filter by Data Preparation MLCube"), ): """List benchmarks""" + filters = { + "name": name, + "owner": owner, + "state": state, + "is_valid": is_valid, + "is_active": is_active, + "data_preparation_mlcube": data_prep + } + EntityList.run( Benchmark, - fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], + fields=["UID", "Name", "Description", "Data Preparation MLCube", "State", "Approval Status", "Registered"], unregistered=unregistered, mine_only=mine, + **filters, ) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index fc18022ac..3d2fb64b4 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -24,6 +24,10 @@ def list( mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), + name: str = typer.Option(None, "--name", help="Filter by name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + 
is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"), ): """List datasets""" EntityList.run( @@ -32,6 +36,10 @@ def list( unregistered=unregistered, mine_only=mine, mlcube=mlcube, + name=name, + owner=owner, + state=state, + is_valid=is_valid, ) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 99236ac3f..d8c6008ef 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -21,8 +21,9 @@ def run( Args: unregistered (bool, optional): Display only local unregistered results. Defaults to False. mine_only (bool, optional): Display all registered current-user results. Defaults to False. - kwargs (dict): Additional parameters for filtering entity lists. + kwargs (dict): Additional parameters for filtering entity lists. Keys with None will be filtered out. """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} entity_list = EntityList( entity_class, fields, unregistered, mine_only, **kwargs ) diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 9256f35f2..fd7eb2287 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -20,6 +20,10 @@ def list( False, "--unregistered", help="Get unregistered mlcubes" ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), + name: str = typer.Option(None, "--name", "-n", help="Filter out by MLCube Name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner ID"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"), ): """List mlcubes""" EntityList.run( @@ -27,6 +31,10 @@ def list( fields=["UID", "Name", "State", "Registered"], unregistered=unregistered, mine_only=mine, + name=name, + owner=owner, + state=state, + is_active=is_active, ) diff --git a/cli/medperf/commands/result/result.py 
b/cli/medperf/commands/result/result.py index 40b65c52e..7518015e7 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -69,6 +69,12 @@ def list( benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), + model: int = typer.Option( + None, "--model", "-m", help="Get results for a given model" + ), + dataset: int = typer.Option( + None, "--dataset", "-d", help="Get results for a given dataset" + ), ): """List results""" EntityList.run( @@ -77,6 +83,8 @@ def list( unregistered=unregistered, mine_only=mine, benchmark=benchmark, + model=model, + dataset=dataset, ) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..91a75da97 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -81,6 +81,7 @@ def __get_list( page_size=config.default_page_size, offset=0, binary_reduction=False, + filters={}, ): """Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained. If num_elements is None, then iterates until all elements have been retrieved. @@ -104,7 +105,9 @@ def __get_list( num_elements = float("inf") while len(el_list) < num_elements: - paginated_url = f"{url}?limit={page_size}&offset={offset}" + query_params = {**filters, "limit": page_size, "offset": offset} + query_str = "&".join([f"{k}={v}" for k, v in query_params.items()]) + paginated_url = f"{url}?{query_str}" res = self.__auth_get(paginated_url) if res.status_code != 200: if not binary_reduction: @@ -152,13 +155,13 @@ def get_current_user(self): res = self.__auth_get(f"{self.server_url}/me/") return res.json() - def get_benchmarks(self) -> List[dict]: + def get_benchmarks(self, filters={}) -> List[dict]: """Retrieves all benchmarks in the platform. Returns: List[dict]: all benchmarks information.
""" - bmks = self.__get_list(f"{self.server_url}/benchmarks/") + bmks = self.__get_list(f"{self.server_url}/benchmarks/", filters=filters) return bmks def get_benchmark(self, benchmark_uid: int) -> dict: @@ -179,7 +182,7 @@ def get_benchmark(self, benchmark_uid: int) -> dict: ) return res.json() - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: + def get_benchmark_model_associations(self, benchmark_uid: int, filters={}) -> List[int]: """Retrieves all the model associations of a benchmark. Args: @@ -188,25 +191,28 @@ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: Returns: list[int]: List of benchmark model associations """ - assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models") + assocs = self.__get_list( + f"{self.server_url}/benchmarks/{benchmark_uid}/models", + filters=filters, + ) return filter_latest_associations(assocs, "model_mlcube") - def get_user_benchmarks(self) -> List[dict]: + def get_user_benchmarks(self, filters={}) -> List[dict]: """Retrieves all benchmarks created by the user Returns: List[dict]: Benchmarks data """ - bmks = self.__get_list(f"{self.server_url}/me/benchmarks/") + bmks = self.__get_list(f"{self.server_url}/me/benchmarks/", filters=filters) return bmks - def get_cubes(self) -> List[dict]: + def get_cubes(self, filters={}) -> List[dict]: """Retrieves all MLCubes in the platform Returns: List[dict]: List containing the data of all MLCubes """ - cubes = self.__get_list(f"{self.server_url}/mlcubes/") + cubes = self.__get_list(f"{self.server_url}/mlcubes/", filters=filters) return cubes def get_cube_metadata(self, cube_uid: int) -> dict: @@ -227,13 +233,13 @@ def get_cube_metadata(self, cube_uid: int) -> dict: ) return res.json() - def get_user_cubes(self) -> List[dict]: + def get_user_cubes(self, filters={}) -> List[dict]: """Retrieves metadata from all cubes registered by the user Returns: List[dict]: List of dictionaries containing the mlcubes 
registration information """ - cubes = self.__get_list(f"{self.server_url}/me/mlcubes/") + cubes = self.__get_list(f"{self.server_url}/me/mlcubes/", filters=filters) return cubes def upload_benchmark(self, benchmark_dict: dict) -> int: @@ -268,13 +274,13 @@ def upload_mlcube(self, mlcube_body: dict) -> int: raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}") return res.json() - def get_datasets(self) -> List[dict]: + def get_datasets(self, filters={}) -> List[dict]: """Retrieves all datasets in the platform Returns: List[dict]: List of data from all datasets """ - dsets = self.__get_list(f"{self.server_url}/datasets/") + dsets = self.__get_list(f"{self.server_url}/datasets/", filters=filters) return dsets def get_dataset(self, dset_uid: int) -> dict: @@ -295,13 +301,13 @@ def get_dataset(self, dset_uid: int) -> dict: ) return res.json() - def get_user_datasets(self) -> dict: + def get_user_datasets(self, filters={}) -> dict: """Retrieves all datasets registered by the user Returns: dict: dictionary with the contents of each dataset registration query """ - dsets = self.__get_list(f"{self.server_url}/me/datasets/") + dsets = self.__get_list(f"{self.server_url}/me/datasets/", filters=filters) return dsets def upload_dataset(self, reg_dict: dict) -> int: @@ -320,13 +326,13 @@ def upload_dataset(self, reg_dict: dict) -> int: raise CommunicationRequestError(f"Could not upload the dataset: {details}") return res.json() - def get_results(self) -> List[dict]: + def get_results(self, filters={}) -> List[dict]: """Retrieves all results Returns: List[dict]: List of results """ - res = self.__get_list(f"{self.server_url}/results") + res = self.__get_list(f"{self.server_url}/results", filters=filters) if res.status_code != 200: log_response_error(res) details = format_errors_dict(res.json()) @@ -351,16 +357,16 @@ def get_result(self, result_uid: int) -> dict: ) return res.json() - def get_user_results(self) -> dict: + def get_user_results(self, 
filters={}) -> dict: """Retrieves all results registered by the user Returns: dict: dictionary with the contents of each result registration query """ - results = self.__get_list(f"{self.server_url}/me/results/") + results = self.__get_list(f"{self.server_url}/me/results/", filters=filters) return results - def get_benchmark_results(self, benchmark_id: int) -> dict: + def get_benchmark_results(self, benchmark_id: int, filters={}) -> dict: """Retrieves all results for a given benchmark Args: @@ -370,7 +376,8 @@ def get_benchmark_results(self, benchmark_id: int) -> dict: dict: dictionary with the contents of each result in the specified benchmark """ results = self.__get_list( - f"{self.server_url}/benchmarks/{benchmark_id}/results" + f"{self.server_url}/benchmarks/{benchmark_id}/results", + filters=filters, ) return results @@ -472,22 +479,28 @@ def set_mlcube_association_approval( f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}" ) - def get_datasets_associations(self) -> List[dict]: + def get_datasets_associations(self, filters={}) -> List[dict]: """Get all dataset associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/") + assocs = self.__get_list( + f"{self.server_url}/me/datasets/associations/", + filters=filters, + ) return filter_latest_associations(assocs, "dataset") - def get_cubes_associations(self) -> List[dict]: + def get_cubes_associations(self, filters={}) -> List[dict]: """Get all cube associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/") + assocs = self.__get_list( + f"{self.server_url}/me/mlcubes/associations/", + filters=filters, + ) return filter_latest_associations(assocs, "model_mlcube") def set_mlcube_association_priority( @@ -519,7 +532,7 
@@ def update_dataset(self, dataset_id: int, data: dict): raise CommunicationRequestError(f"Could not update dataset: {details}") return res.json() - def get_mlcube_datasets(self, mlcube_id: int) -> dict: + def get_mlcube_datasets(self, mlcube_id: int, filters={}) -> dict: """Retrieves all datasets that have the specified mlcube as the prep mlcube Args: @@ -529,7 +542,10 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict: dict: dictionary with the contents of each dataset """ - datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/") + datasets = self.__get_list( + f"{self.server_url}/mlcubes/{mlcube_id}/datasets/", + filters=filters, + ) return datasets def get_user(self, user_id: int) -> dict: diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 835fbdf22..d348d9c92 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -75,7 +75,7 @@ def all( @classmethod def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]: comms_fn = cls.remote_prefilter(filters) - entity_meta = comms_fn() + entity_meta = comms_fn(filters) entities = [cls(**meta) for meta in entity_meta] return entities diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index 0e96d1feb..66327d05e 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -63,6 +63,7 @@ def remote_prefilter(filters: dict) -> callable: comms_fn = config.comms.get_results if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: comms_fn = config.comms.get_user_results + del filters["owner"] if "benchmark" in filters and filters["benchmark"] is not None: bmk = filters["benchmark"] diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index fb3596c98..b9e28e921 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -30,7 +30,7 @@ def server(mocker, ui): [1], [], 
(f"{full_url}/benchmarks/1/models",), - {}, + {"filters": {}}, ), ("get_cube_metadata", "get", 200, [1], {}, (f"{full_url}/mlcubes/1/",), {}), ( @@ -284,7 +284,7 @@ def test_get_benchmarks_calls_benchmarks_path(mocker, server, body): bmarks = server.get_benchmarks() # Assert - spy.assert_called_once_with(f"{full_url}/benchmarks/") + spy.assert_called_once_with(f"{full_url}/benchmarks/", filters={}) assert bmarks == [body] @@ -319,7 +319,7 @@ def test_get_user_benchmarks_calls_auth_get_for_expected_path(mocker, server): server.get_user_benchmarks() # Assert - spy.assert_called_once_with(f"{full_url}/me/benchmarks/") + spy.assert_called_once_with(f"{full_url}/me/benchmarks/", filters={}) def test_get_user_benchmarks_returns_benchmarks(mocker, server): @@ -346,7 +346,7 @@ def test_get_mlcubes_calls_mlcubes_path(mocker, server, body): cubes = server.get_cubes() # Assert - spy.assert_called_once_with(f"{full_url}/mlcubes/") + spy.assert_called_once_with(f"{full_url}/mlcubes/", filters={}) assert cubes == [body] @@ -375,7 +375,7 @@ def test_get_user_cubes_calls_auth_get_for_expected_path(mocker, server): server.get_user_cubes() # Assert - spy.assert_called_once_with(f"{full_url}/me/mlcubes/") + spy.assert_called_once_with(f"{full_url}/me/mlcubes/", filters={}) @pytest.mark.parametrize("body", [{"dset": 1}, {}, {"test": "test"}]) @@ -387,7 +387,7 @@ def test_get_datasets_calls_datasets_path(mocker, server, body): dsets = server.get_datasets() # Assert - spy.assert_called_once_with(f"{full_url}/datasets/") + spy.assert_called_once_with(f"{full_url}/datasets/", filters={}) assert dsets == [body] @@ -418,7 +418,7 @@ def test_get_user_datasets_calls_auth_get_for_expected_path(mocker, server): server.get_user_datasets() # Assert - spy.assert_called_once_with(f"{full_url}/me/datasets/") + spy.assert_called_once_with(f"{full_url}/me/datasets/", filters={}) @pytest.mark.parametrize("body", [{"mlcube": 1}, {}, {"test": "test"}]) @@ -529,7 +529,7 @@ def 
test_get_datasets_associations_gets_associations(mocker, server): server.get_datasets_associations() # Assert - spy.assert_called_once_with(exp_path) + spy.assert_called_once_with(exp_path, filters={}) def test_get_cubes_associations_gets_associations(mocker, server): @@ -541,7 +541,7 @@ def test_get_cubes_associations_gets_associations(mocker, server): server.get_cubes_associations() # Assert - spy.assert_called_once_with(exp_path) + spy.assert_called_once_with(exp_path, filters={}) @pytest.mark.parametrize("uid", [448, 53, 312]) @@ -621,5 +621,5 @@ def test_get_mlcube_datasets_calls_auth_get_for_expected_path(mocker, server): exp_datasets = server.get_mlcube_datasets(cube_id) # Assert - spy.assert_called_once_with(f"{full_url}/mlcubes/{cube_id}/datasets/") + spy.assert_called_once_with(f"{full_url}/mlcubes/{cube_id}/datasets/", filters={}) assert exp_datasets == datasets diff --git a/server/benchmark/tests/test_.py b/server/benchmark/tests/test_.py index d2a22445d..6a12f503b 100644 --- a/server/benchmark/tests/test_.py +++ b/server/benchmark/tests/test_.py @@ -281,6 +281,34 @@ def test_generic_get_benchmark_list(self): self.assertEqual(len(response.data["results"]), 1) self.assertEqual(response.data["results"][0]["id"], benchmark_id) + @parameterized.expand([ + ({"name": "bmk2"}, 1), + ({"name": "bmk2", "owner": None}, 1), # id assigned dynamically + ({"name": "nonexistent"}, 0), + ({"is_valid": True}, 2), + ({"data_preparation_mlcube": None}, 1), # id assigned dynamically + ]) + def test_get_with_query_params(self, query_dict, expected_num_results): + # Arrange + # Create another benchmark + benchmark = self.mock_benchmark( + self.ref_model["id"], self.prep["id"], self.eval["id"], name="bmk2" + ) + benchmark = self.create_benchmark(benchmark).data + if 'owner' in query_dict: + query_dict['owner'] = benchmark['owner'] + if 'data_preparation_mlcube' in query_dict: + query_dict['data_preparation_mlcube'] = self.prep['id'] + query_str = "&".join([f"{k}={v}" for k, v
in query_dict.items()]) + + # Act + query_url = self.url + f"?{query_str}" + response = self.client.get(query_url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), expected_num_results) + class PermissionTest(BenchmarkTest): """Test module for permissions of /benchmarks/ endpoint diff --git a/server/benchmark/views.py b/server/benchmark/views.py index 2ca3fd619..30687b191 100644 --- a/server/benchmark/views.py +++ b/server/benchmark/views.py @@ -15,6 +15,7 @@ class BenchmarkList(GenericAPIView): serializer_class = BenchmarkSerializer queryset = "" + filterset_fields = ('name', 'owner', 'state', 'is_valid', 'is_active', 'approval_status', 'data_preparation_mlcube') @extend_schema(operation_id="benchmarks_retrieve_all") def get(self, request, format=None): @@ -22,6 +23,7 @@ def get(self, request, format=None): List all benchmarks """ benchmarks = Benchmark.objects.all() + benchmarks = self.filter_queryset(benchmarks) benchmarks = self.paginate_queryset(benchmarks) serializer = BenchmarkSerializer(benchmarks, many=True) return self.get_paginated_response(serializer.data) diff --git a/server/benchmarkmodel/views.py b/server/benchmarkmodel/views.py index c4ec01301..55851acaf 100644 --- a/server/benchmarkmodel/views.py +++ b/server/benchmarkmodel/views.py @@ -16,6 +16,7 @@ class BenchmarkModelList(GenericAPIView): permission_classes = [IsAdmin | IsBenchmarkOwner | IsMlCubeOwner] serializer_class = BenchmarkModelListSerializer queryset = "" + filterset_fields = ('model_mlcube', 'benchmark', 'initiated_by', 'priority') def post(self, request, format=None): """ diff --git a/server/medperf/settings.py b/server/medperf/settings.py index a9c83f6fe..4e7d002fc 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -84,6 +84,7 @@ "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", + "django_filters", "benchmark", "dataset", "benchmarkdataset", @@ -224,6 
+225,7 @@ "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.NamespaceVersioning", + 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend'], "DEFAULT_PARSER_CLASSES": [ "rest_framework.parsers.JSONParser", ], diff --git a/server/mlcube/views.py b/server/mlcube/views.py index b46d8000f..28d672935 100644 --- a/server/mlcube/views.py +++ b/server/mlcube/views.py @@ -14,6 +14,7 @@ class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" + filterset_fields = ('name', 'owner', 'state', 'is_valid') @extend_schema(operation_id="mlcubes_retrieve_all") def get(self, request, format=None): @@ -21,6 +22,7 @@ def get(self, request, format=None): List all mlcubes """ mlcubes = MlCube.objects.all() + mlcubes = self.filter_queryset(mlcubes) mlcubes = self.paginate_queryset(mlcubes) serializer = MlCubeSerializer(mlcubes, many=True) return self.get_paginated_response(serializer.data) diff --git a/server/requirements.txt b/server/requirements.txt index 7ad3b76e2..87d9c2b47 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -13,3 +13,4 @@ pyOpenSSL==24.1.0 Werkzeug==3.0.1 django-extensions==3.2.3 djangorestframework-simplejwt==5.3.1 +django-filter==24.3 \ No newline at end of file diff --git a/server/result/views.py b/server/result/views.py index a8d3debd3..08e9804ae 100644 --- a/server/result/views.py +++ b/server/result/views.py @@ -12,6 +12,7 @@ class ModelResultList(GenericAPIView): serializer_class = ModelResultSerializer queryset = "" + filterset_fields = ('name', 'owner', 'benchmark', 'model', 'dataset', 'is_valid', 'approval_status') def get_permissions(self): if self.request.method == "GET": @@ -26,6 +27,7 @@ def get(self, request, format=None): List all results """ modelresults = ModelResult.objects.all() + modelresults = 
self.filter_queryset(modelresults) modelresults = self.paginate_queryset(modelresults) serializer = ModelResultSerializer(modelresults, many=True) return self.get_paginated_response(serializer.data)