From 16cf202b930a370eecb31d65296a9fece19e26a9 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 13 Nov 2024 15:23:05 -0500 Subject: [PATCH 1/9] Implement query parameters on main entities --- server/benchmark/views.py | 2 ++ server/benchmarkmodel/views.py | 1 + server/medperf/settings.py | 2 ++ server/mlcube/views.py | 2 ++ server/requirements.txt | 3 ++- server/result/views.py | 2 ++ 6 files changed, 11 insertions(+), 1 deletion(-) diff --git a/server/benchmark/views.py b/server/benchmark/views.py index 2ca3fd619..30687b191 100644 --- a/server/benchmark/views.py +++ b/server/benchmark/views.py @@ -15,6 +15,7 @@ class BenchmarkList(GenericAPIView): serializer_class = BenchmarkSerializer queryset = "" + filterset_fields = ('name', 'owner', 'state', 'is_valid', 'is_active', 'approval_status', 'data_preparation_mlcube') @extend_schema(operation_id="benchmarks_retrieve_all") def get(self, request, format=None): @@ -22,6 +23,7 @@ def get(self, request, format=None): List all benchmarks """ benchmarks = Benchmark.objects.all() + benchmarks = self.filter_queryset(benchmarks) benchmarks = self.paginate_queryset(benchmarks) serializer = BenchmarkSerializer(benchmarks, many=True) return self.get_paginated_response(serializer.data) diff --git a/server/benchmarkmodel/views.py b/server/benchmarkmodel/views.py index c4ec01301..55851acaf 100644 --- a/server/benchmarkmodel/views.py +++ b/server/benchmarkmodel/views.py @@ -16,6 +16,7 @@ class BenchmarkModelList(GenericAPIView): permission_classes = [IsAdmin | IsBenchmarkOwner | IsMlCubeOwner] serializer_class = BenchmarkModelListSerializer queryset = "" + filterset_fields = ('model_mlcube', 'benchmark', 'initiated_by', 'priority') def post(self, request, format=None): """ diff --git a/server/medperf/settings.py b/server/medperf/settings.py index a9c83f6fe..4e7d002fc 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -84,6 +84,7 @@ "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", + "django_filters", "benchmark", "dataset", "benchmarkdataset", @@ -224,6 +225,7 @@ "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.NamespaceVersioning", + 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend'], "DEFAULT_PARSER_CLASSES": [ "rest_framework.parsers.JSONParser", ], diff --git a/server/mlcube/views.py b/server/mlcube/views.py index b46d8000f..28d672935 100644 --- a/server/mlcube/views.py +++ b/server/mlcube/views.py @@ -14,6 +14,7 @@ class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" + filterset_fields = ('name', 'owner', 'state', 'is_valid') @extend_schema(operation_id="mlcubes_retrieve_all") def get(self, request, format=None): @@ -21,6 +22,7 @@ def get(self, request, format=None): List all mlcubes """ mlcubes = MlCube.objects.all() + mlcubes = self.filter_queryset(mlcubes) mlcubes = self.paginate_queryset(mlcubes) serializer = MlCubeSerializer(mlcubes, many=True) return self.get_paginated_response(serializer.data) diff --git a/server/requirements.txt b/server/requirements.txt index 7ad3b76e2..f22eed666 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,4 +1,4 @@ -Django==4.2.11 +Django==4.2.16 djangorestframework==3.14.0 drf-spectacular==0.27.1 drf-spectacular-sidecar==2024.3.4 @@ -13,3 +13,4 @@ pyOpenSSL==24.1.0 Werkzeug==3.0.1 django-extensions==3.2.3 djangorestframework-simplejwt==5.3.1 +django-filter==24.3 \ No newline at end of file diff --git a/server/result/views.py b/server/result/views.py index a8d3debd3..08e9804ae 100644 --- a/server/result/views.py +++ b/server/result/views.py @@ -12,6 +12,7 @@ class ModelResultList(GenericAPIView): serializer_class = ModelResultSerializer queryset = "" + filterset_fields = ('name', 'owner', 'benchmark', 'model', 'dataset', 'is_valid', 'approval_status') def get_permissions(self): if self.request.method == "GET": @@ -26,6 +27,7 @@ def get(self, request, format=None): List all results """ modelresults = ModelResult.objects.all() + modelresults = self.filter_queryset(modelresults) modelresults = self.paginate_queryset(modelresults) serializer = ModelResultSerializer(modelresults, many=True) return self.get_paginated_response(serializer.data) From dcd46e3c007acec9f8dfdfde979952f91b9be7d2 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 13 Nov 2024 17:55:10 -0500 Subject: [PATCH 2/9] revert django update --- server/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/requirements.txt b/server/requirements.txt index f22eed666..87d9c2b47 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,4 +1,4 @@ -Django==4.2.16 +Django==4.2.11 djangorestframework==3.14.0 drf-spectacular==0.27.1 drf-spectacular-sidecar==2024.3.4 From 3ca03a6f0dbeb3fdc96374383ec9d4a0c04bf987 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 20 Nov 2024 11:43:36 -0500 Subject: [PATCH 3/9] Implement list query filtering in the CLI --- cli/medperf/commands/list.py | 3 +- cli/medperf/comms/rest.py | 70 +++++++++++++++++++------------ cli/medperf/entities/interface.py | 2 +- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 99236ac3f..d8c6008ef 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -21,8 +21,9 @@ def run( Args: unregistered (bool, optional): Display only local unregistered results. Defaults to False. mine_only (bool, optional): Display all registered current-user results. Defaults to False. - kwargs (dict): Additional parameters for filtering entity lists. + kwargs (dict): Additional parameters for filtering entity lists. Keys with None will be filtered out. """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} entity_list = EntityList( entity_class, fields, unregistered, mine_only, **kwargs ) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..3c6fc0c55 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -81,6 +81,7 @@ def __get_list( page_size=config.default_page_size, offset=0, binary_reduction=False, + filters={}, ): """Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained. If num_elements is None, then iterates until all elements have been retrieved. @@ -104,7 +105,9 @@ def __get_list( num_elements = float("inf") while len(el_list) < num_elements: - paginated_url = f"{url}?limit={page_size}&offset={offset}" + filters.update({"limit": page_size, "offset": {offset}}) + query_str = "&".join([f"{k}={v}" for k, v in filters.items()]) + paginated_url = f"{url}?{query_str}" res = self.__auth_get(paginated_url) if res.status_code != 200: if not binary_reduction: @@ -152,13 +155,13 @@ def get_current_user(self): res = self.__auth_get(f"{self.server_url}/me/") return res.json() - def get_benchmarks(self) -> List[dict]: + def get_benchmarks(self, filters={}) -> List[dict]: """Retrieves all benchmarks in the platform. Returns: List[dict]: all benchmarks information. """ - bmks = self.__get_list(f"{self.server_url}/benchmarks/") + bmks = self.__get_list(f"{self.server_url}/benchmarks/", filters=filters) return bmks def get_benchmark(self, benchmark_uid: int) -> dict: @@ -179,7 +182,7 @@ def get_benchmark(self, benchmark_uid: int) -> dict: ) return res.json() - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: + def get_benchmark_model_associations(self, benchmark_uid: int, filters={}) -> List[int]: """Retrieves all the model associations of a benchmark. Args: @@ -188,25 +191,28 @@ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: Returns: list[int]: List of benchmark model associations """ - assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models") + assocs = self.__get_list( + f"{self.server_url}/benchmarks/{benchmark_uid}/models", + filters=filters, + ) return filter_latest_associations(assocs, "model_mlcube") - def get_user_benchmarks(self) -> List[dict]: + def get_user_benchmarks(self, filters={}) -> List[dict]: """Retrieves all benchmarks created by the user Returns: List[dict]: Benchmarks data """ - bmks = self.__get_list(f"{self.server_url}/me/benchmarks/") + bmks = self.__get_list(f"{self.server_url}/me/benchmarks/", filters=filters) return bmks - def get_cubes(self) -> List[dict]: + def get_cubes(self, filters={}) -> List[dict]: """Retrieves all MLCubes in the platform Returns: List[dict]: List containing the data of all MLCubes """ - cubes = self.__get_list(f"{self.server_url}/mlcubes/") + cubes = self.__get_list(f"{self.server_url}/mlcubes/", filters=filters) return cubes def get_cube_metadata(self, cube_uid: int) -> dict: @@ -227,13 +233,13 @@ def get_cube_metadata(self, cube_uid: int) -> dict: ) return res.json() - def get_user_cubes(self) -> List[dict]: + def get_user_cubes(self, filters={}) -> List[dict]: """Retrieves metadata from all cubes registered by the user Returns: List[dict]: List of dictionaries containing the mlcubes registration information """ - cubes = self.__get_list(f"{self.server_url}/me/mlcubes/") + cubes = self.__get_list(f"{self.server_url}/me/mlcubes/", filters=filters) return cubes def upload_benchmark(self, benchmark_dict: dict) -> int: @@ -268,13 +274,13 @@ def upload_mlcube(self, mlcube_body: dict) -> int: raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}") return res.json() - def get_datasets(self) -> List[dict]: + def get_datasets(self, filters={}) -> List[dict]: """Retrieves all datasets in the platform Returns: List[dict]: List of data from all datasets """ - dsets = self.__get_list(f"{self.server_url}/datasets/") + dsets = self.__get_list(f"{self.server_url}/datasets/", filters=filters) return dsets def get_dataset(self, dset_uid: int) -> dict: @@ -295,13 +301,13 @@ def get_dataset(self, dset_uid: int) -> dict: ) return res.json() - def get_user_datasets(self) -> dict: + def get_user_datasets(self, filters={}) -> dict: """Retrieves all datasets registered by the user Returns: dict: dictionary with the contents of each dataset registration query """ - dsets = self.__get_list(f"{self.server_url}/me/datasets/") + dsets = self.__get_list(f"{self.server_url}/me/datasets/", filters=filters) return dsets def upload_dataset(self, reg_dict: dict) -> int: @@ -320,13 +326,13 @@ def upload_dataset(self, reg_dict: dict) -> int: raise CommunicationRequestError(f"Could not upload the dataset: {details}") return res.json() - def get_results(self) -> List[dict]: + def get_results(self, filters={}) -> List[dict]: """Retrieves all results Returns: List[dict]: List of results """ - res = self.__get_list(f"{self.server_url}/results") + res = self.__get_list(f"{self.server_url}/results", filters=filters) if res.status_code != 200: log_response_error(res) details = format_errors_dict(res.json()) @@ -351,16 +357,16 @@ def get_result(self, result_uid: int) -> dict: ) return res.json() - def get_user_results(self) -> dict: + def get_user_results(self, filters={}) -> dict: """Retrieves all results registered by the user Returns: dict: dictionary with the contents of each result registration query """ - results = self.__get_list(f"{self.server_url}/me/results/") + results = self.__get_list(f"{self.server_url}/me/results/", filters=filters) return results - def get_benchmark_results(self, benchmark_id: int) -> dict: + def get_benchmark_results(self, benchmark_id: int, filters={}) -> dict: """Retrieves all results for a given benchmark Args: @@ -370,7 +376,8 @@ def get_benchmark_results(self, benchmark_id: int) -> dict: dict: dictionary with the contents of each result in the specified benchmark """ results = self.__get_list( - f"{self.server_url}/benchmarks/{benchmark_id}/results" + f"{self.server_url}/benchmarks/{benchmark_id}/results", + filters=filters, ) return results @@ -472,22 +479,28 @@ def set_mlcube_association_approval( f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}" ) - def get_datasets_associations(self) -> List[dict]: + def get_datasets_associations(self, filters={}) -> List[dict]: """Get all dataset associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/") + assocs = self.__get_list( + f"{self.server_url}/me/datasets/associations/", + filters=filters, + ) return filter_latest_associations(assocs, "dataset") - def get_cubes_associations(self) -> List[dict]: + def get_cubes_associations(self, filters={}) -> List[dict]: """Get all cube associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/") + assocs = self.__get_list( + f"{self.server_url}/me/mlcubes/associations/", + filters=filters, + ) return filter_latest_associations(assocs, "model_mlcube") def set_mlcube_association_priority( @@ -519,7 +532,7 @@ def update_dataset(self, dataset_id: int, data: dict): raise CommunicationRequestError(f"Could not update dataset: {details}") return res.json() - def get_mlcube_datasets(self, mlcube_id: int) -> dict: + def get_mlcube_datasets(self, mlcube_id: int, filters={}) -> dict: """Retrieves all datasets that have the specified mlcube as the prep mlcube Args: @@ -529,7 +542,10 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict: dict: dictionary with the contents of each dataset """ - datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/") + datasets = self.__get_list( + f"{self.server_url}/mlcubes/{mlcube_id}/datasets/", + filters=filters, + ) return datasets def get_user(self, user_id: int) -> dict: diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 835fbdf22..d348d9c92 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -75,7 +75,7 @@ def all( @classmethod def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]: comms_fn = cls.remote_prefilter(filters) - entity_meta = comms_fn() + entity_meta = comms_fn(filters) entities = [cls(**meta) for meta in entity_meta] return entities From 50cd30b705e1f7cff3423ed65065441a4579573d Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 20 Nov 2024 16:59:07 -0500 Subject: [PATCH 4/9] Add list filters to main entities --- cli/medperf/commands/benchmark/benchmark.py | 18 +++++++++++++++++- cli/medperf/commands/dataset/dataset.py | 8 ++++++++ cli/medperf/commands/mlcube/mlcube.py | 8 ++++++++ cli/medperf/commands/result/result.py | 8 ++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index 35d719b0d..887ad7a2c 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -20,13 +20,29 @@ def list( False, "--unregistered", help="Get unregistered benchmarks" ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), + name: str = typer.Option(None, "--name", help="Filter by name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"), + is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"), + data_prep: int = typer.Option(None, "-d", "--data-preparation-mlcube", help="Filter by Data Preparation MLCube"), ): """List benchmarks""" + filters = { + "name": name, + "owner": owner, + "state": state, + "is_valid": is_valid, + "is_active": is_active, + "data_preparation_mlcube": data_prep + } + EntityList.run( Benchmark, - fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], + fields=["UID", "Name", "Description", "Data Preparation MLCube", "State", "Approval Status", "Registered"], unregistered=unregistered, mine_only=mine, + **filters, ) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index fc18022ac..3d2fb64b4 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -24,6 +24,10 @@ def list( mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), + name: str = typer.Option(None, "--name", help="Filter by name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"), ): """List datasets""" EntityList.run( @@ -32,6 +36,10 @@ def list( unregistered=unregistered, mine_only=mine, mlcube=mlcube, + name=name, + owner=owner, + state=state, + is_valid=is_valid, ) diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 9256f35f2..fd7eb2287 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -20,6 +20,10 @@ def list( False, "--unregistered", help="Get unregistered mlcubes" ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), + name: str = typer.Option(None, "--name", "-n", help="Filter out by MLCube Name"), + owner: int = typer.Option(None, "--owner", help="Filter by owner ID"), + state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), + is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"), ): """List mlcubes""" EntityList.run( @@ -27,6 +31,10 @@ def list( fields=["UID", "Name", "State", "Registered"], unregistered=unregistered, mine_only=mine, + name=name, + owner=owner, + state=state, + is_active=is_active, ) diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 40b65c52e..7518015e7 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -69,6 +69,12 @@ def list( benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), + model: int = typer.Option( + None, "--owner", "-o", help="Get results for a given model" + ), + dataset: int = typer.Option( + None, "--dataset", "-d", help="Get reuslts for a given dataset" + ), ): """List results""" EntityList.run( @@ -77,6 +83,8 @@ def list( unregistered=unregistered, mine_only=mine, benchmark=benchmark, + model=model, + dataset=dataset, ) From 369ada4cea6aa28aa25366f88aa45dc323e54c11 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Tue, 26 Nov 2024 12:33:59 -0500 Subject: [PATCH 5/9] remove owner query for /me/results --- cli/medperf/entities/result.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index 0e96d1feb..66327d05e 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -63,6 +63,7 @@ def remote_prefilter(filters: dict) -> callable: comms_fn = config.comms.get_results if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: comms_fn = config.comms.get_user_results + del filters["owner"] if "benchmark" in filters and filters["benchmark"] is not None: bmk = filters["benchmark"] From 3abb082209e7b249d08c9bda0bfea95b594fae5c Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Tue, 3 Dec 2024 12:40:06 -0500 Subject: [PATCH 6/9] Fix rest tests --- cli/medperf/comms/rest.py | 2 +- cli/medperf/tests/comms/test_rest.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 3c6fc0c55..91a75da97 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -105,7 +105,7 @@ def __get_list( num_elements = float("inf") while len(el_list) < num_elements: - filters.update({"limit": page_size, "offset": {offset}}) + filters.update({"limit": page_size, "offset": offset}) query_str = "&".join([f"{k}={v}" for k, v in filters.items()]) paginated_url = f"{url}?{query_str}" res = self.__auth_get(paginated_url) diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index fb3596c98..b9e28e921 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -30,7 +30,7 @@ def server(mocker, ui): [1], [], (f"{full_url}/benchmarks/1/models",), - {}, + {"filters": {}}, ), ("get_cube_metadata", "get", 200, [1], {}, (f"{full_url}/mlcubes/1/",), {}), ( @@ -284,7 +284,7 @@ def test_get_benchmarks_calls_benchmarks_path(mocker, server, body): bmarks = server.get_benchmarks() # Assert - spy.assert_called_once_with(f"{full_url}/benchmarks/") + spy.assert_called_once_with(f"{full_url}/benchmarks/", filters={}) assert bmarks == [body] @@ -319,7 +319,7 @@ def test_get_user_benchmarks_calls_auth_get_for_expected_path(mocker, server): server.get_user_benchmarks() # Assert - spy.assert_called_once_with(f"{full_url}/me/benchmarks/") + spy.assert_called_once_with(f"{full_url}/me/benchmarks/", filters={}) def test_get_user_benchmarks_returns_benchmarks(mocker, server): @@ -346,7 +346,7 @@ def test_get_mlcubes_calls_mlcubes_path(mocker, server, body): cubes = server.get_cubes() # Assert - spy.assert_called_once_with(f"{full_url}/mlcubes/") + spy.assert_called_once_with(f"{full_url}/mlcubes/", filters={}) assert cubes == [body] @@ -375,7 +375,7 @@ def test_get_user_cubes_calls_auth_get_for_expected_path(mocker, server): server.get_user_cubes() # Assert - spy.assert_called_once_with(f"{full_url}/me/mlcubes/") + spy.assert_called_once_with(f"{full_url}/me/mlcubes/", filters={}) @pytest.mark.parametrize("body", [{"dset": 1}, {}, {"test": "test"}]) @@ -387,7 +387,7 @@ def test_get_datasets_calls_datasets_path(mocker, server, body): dsets = server.get_datasets() # Assert - spy.assert_called_once_with(f"{full_url}/datasets/") + spy.assert_called_once_with(f"{full_url}/datasets/", filters={}) assert dsets == [body] @@ -418,7 +418,7 @@ def test_get_user_datasets_calls_auth_get_for_expected_path(mocker, server): server.get_user_datasets() # Assert - spy.assert_called_once_with(f"{full_url}/me/datasets/") + spy.assert_called_once_with(f"{full_url}/me/datasets/", filters={}) @pytest.mark.parametrize("body", [{"mlcube": 1}, {}, {"test": "test"}]) @@ -529,7 +529,7 @@ def test_get_datasets_associations_gets_associations(mocker, server): server.get_datasets_associations() # Assert - spy.assert_called_once_with(exp_path) + spy.assert_called_once_with(exp_path, filters={}) def test_get_cubes_associations_gets_associations(mocker, server): @@ -541,7 +541,7 @@ def test_get_cubes_associations_gets_associations(mocker, server): server.get_cubes_associations() # Assert - spy.assert_called_once_with(exp_path) + spy.assert_called_once_with(exp_path, filters={}) @pytest.mark.parametrize("uid", [448, 53, 312]) @@ -621,5 +621,5 @@ def test_get_mlcube_datasets_calls_auth_get_for_expected_path(mocker, server): exp_datasets = server.get_mlcube_datasets(cube_id) # Assert - spy.assert_called_once_with(f"{full_url}/mlcubes/{cube_id}/datasets/") + spy.assert_called_once_with(f"{full_url}/mlcubes/{cube_id}/datasets/", filters={}) assert exp_datasets == datasets From f5bd3432834c419cde9a6543ba2cca56fa55792a Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Fri, 13 Dec 2024 12:41:35 -0500 Subject: [PATCH 7/9] Add test for query params --- server/benchmark/tests/test_.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/server/benchmark/tests/test_.py b/server/benchmark/tests/test_.py index d2a22445d..b8cc5e26c 100644 --- a/server/benchmark/tests/test_.py +++ b/server/benchmark/tests/test_.py @@ -281,6 +281,30 @@ def test_generic_get_benchmark_list(self): self.assertEqual(len(response.data["results"]), 1) self.assertEqual(response.data["results"][0]["id"], benchmark_id) + @parameterized.expand([ + ({"name": "bmk2"}, 1), + ({"name": "string", "owner": 3}, 1), + ({"name": "nonexistent"}, 0), + ({"is_valid": True}, 2), + ({"data_preparation_mlcube": 2}, 1), + ]) + def test_get_with_query_params(self, query_dict, expected_num_results): + # Arrange + # Create another benchmark + benchmark = self.mock_benchmark( + self.ref_model["id"], self.prep["id"], self.eval["id"], name="bmk2" + ) + benchmark = self.create_benchmark(benchmark).data + query_str = "&".join([f"{k}={v}" for k, v in query_dict.items()]) + + # Act + query_url = self.url + f"?{query_str}" + response = self.client.get(query_url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), expected_num_results) + class PermissionTest(BenchmarkTest): """Test module for permissions of /benchmarks/ endpoint From 6bdcfdff24b3e8b18904415d8163c34fb85b07b3 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 18 Dec 2024 14:50:25 -0500 Subject: [PATCH 8/9] Fix tests for postgres deployment --- server/benchmark/tests/test_.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/server/benchmark/tests/test_.py b/server/benchmark/tests/test_.py index b8cc5e26c..6ed863634 100644 --- a/server/benchmark/tests/test_.py +++ b/server/benchmark/tests/test_.py @@ -283,10 +283,10 @@ def test_generic_get_benchmark_list(self): @parameterized.expand([ ({"name": "bmk2"}, 1), - ({"name": "string", "owner": 3}, 1), + ({"name": "bmk2", "owner": None}, 1), # id assigned dynamically ({"name": "nonexistent"}, 0), ({"is_valid": True}, 2), - ({"data_preparation_mlcube": 2}, 1), + ({"data_preparation_mlcube": None}, 1), # id assigned dinamically ]) def test_get_with_query_params(self, query_dict, expected_num_results): # Arrange @@ -295,11 +295,17 @@ def test_get_with_query_params(self, query_dict, expected_num_results): self.ref_model["id"], self.prep["id"], self.eval["id"], name="bmk2" ) benchmark = self.create_benchmark(benchmark).data + if 'owner' in query_dict: + query_dict['owner'] = benchmark['owner'] + if 'data_preparation_mlcube' in query_dict: + query_dict['data_preparation_mlcube'] = self.prep['id'] query_str = "&".join([f"{k}={v}" for k, v in query_dict.items()]) # Act query_url = self.url + f"?{query_str}" response = self.client.get(query_url) + print(query_url, response.status_code, response.data['results']) + print("Unfiltered Query:", self.client.get(self.url).data['results']) # Assert self.assertEqual(response.status_code, status.HTTP_200_OK) From 926244b990c13a054ff3519236b67278a08b1929 Mon Sep 17 00:00:00 2001 From: Alejandro Aristizabal Date: Wed, 18 Dec 2024 14:50:45 -0500 Subject: [PATCH 9/9] Remove unnecessary prints --- server/benchmark/tests/test_.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/benchmark/tests/test_.py b/server/benchmark/tests/test_.py index 6ed863634..6a12f503b 100644 --- a/server/benchmark/tests/test_.py +++ b/server/benchmark/tests/test_.py @@ -304,8 +304,6 @@ def test_get_with_query_params(self, query_dict, expected_num_results): # Act query_url = self.url + f"?{query_str}" response = self.client.get(query_url) - print(query_url, response.status_code, response.data['results']) - print("Unfiltered Query:", self.client.get(self.url).data['results']) # Assert self.assertEqual(response.status_code, status.HTTP_200_OK)