Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions cli/medperf/comms/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def __req(self, url, req_func, **kwargs):
"remember to provide the server certificate through --certificate"
)

def __get_count(self, url, filters={}, error_msg="") -> int:
filters = dict(filters)
filters.update({"is_valid": True, "limit": 1, "offset": 0})

query_str = "&".join([f"{k}={v}" for k, v in filters.items()])
paginated_url = f"{url}?{query_str}"

res = self.__auth_get(paginated_url)
if res.status_code != 200:
log_response_error(res)
details = format_errors_dict(res.json())
raise CommunicationRetrievalError(f"{error_msg}: {details}")

return res.json()["count"]

def __get_list(
self,
url,
Expand All @@ -95,6 +110,15 @@ def __get_list(
Returns:
List[dict]: A list of dictionaries representing the retrieved elements.
"""

filters = dict(filters)
if filters.get("limit", None) is not None:
page_size = filters["limit"]
num_elements = filters["limit"]

if filters.get("offset", None) is not None:
offset = filters["offset"]

el_list = []
filters.update({"is_valid": True})
if num_elements is None:
Expand Down Expand Up @@ -976,3 +1000,42 @@ def get_certificate_encrypted_keys(
url = f"{self.server_url}/certificates/{certificate_id}/encrypted_keys/"
error_msg = f"Could not retrieve encrypted keys of certificate {certificate_id}"
return self.__get_list(url=url, filters=filters, error_msg=error_msg)

def get_benchmarks_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of benchmarks in the platform.

Returns:
int: count of all benchmarks.
"""
if is_owner:
url = f"{self.server_url}/me/benchmarks/"
else:
url = f"{self.server_url}/benchmarks/"
error_msg = "Could not retrieve benchmarks count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_cubes_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of MLCubes in the platform.

Returns:
int: count of all MLCubes.
"""
if is_owner:
url = f"{self.server_url}/me/mlcubes/"
else:
url = f"{self.server_url}/mlcubes/"
error_msg = "Could not retrieve mlcubes count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_datasets_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of datasets in the platform.

Returns:
int: count of all datasets.
"""
if is_owner:
url = f"{self.server_url}/me/datasets/"
else:
url = f"{self.server_url}/datasets/"
error_msg = "Could not retrieve datasets count"
return self.__get_count(url, filters=filters, error_msg=error_msg)
7 changes: 6 additions & 1 deletion cli/medperf/entities/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple
from medperf.commands.association.utils import (
get_experiment_associations,
get_user_associations,
Expand Down Expand Up @@ -85,6 +85,11 @@ def remote_prefilter(filters: dict) -> callable:
comms_fn = config.comms.get_user_benchmarks
return comms_fn

@staticmethod
def remote_prefilter_counter(filters: dict) -> Tuple[callable, bool]:
owner = "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]
return config.comms.get_benchmarks_count, owner

@classmethod
def get_models_uids(cls, benchmark_uid: int) -> List[int]:
"""Retrieves the list of models associated to the benchmark
Expand Down
7 changes: 6 additions & 1 deletion cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
from medperf.commands.association.utils import get_user_associations
from pydantic import Field

Expand Down Expand Up @@ -110,6 +110,11 @@ def remote_prefilter(filters: dict):

return comms_fn

@staticmethod
def remote_prefilter_counter(filters: dict) -> Tuple[callable, bool]:
owner = "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]
return config.comms.get_cubes_count, owner

@classmethod
def get(
cls,
Expand Down
7 changes: 6 additions & 1 deletion cli/medperf/entities/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from medperf.commands.association.utils import get_user_associations
import yaml
from pydantic import Field, validator
from typing import Optional, Union, List
from typing import Optional, Tuple, Union, List

from medperf.utils import remove_path
from medperf.entities.interface import Entity
Expand Down Expand Up @@ -121,6 +121,11 @@ def func():

return comms_fn

@staticmethod
def remote_prefilter_counter(filters: dict) -> Tuple[callable, bool]:
owner = "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]
return config.comms.get_datasets_count, owner

@classmethod
def get_benchmarks_associations(cls, dataset_uid: int) -> List[dict]:
"""Retrieves the list of benchmarks dataset is associated with
Expand Down
25 changes: 24 additions & 1 deletion cli/medperf/entities/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Union, Callable
from typing import List, Dict, Tuple, Union, Callable
from abc import ABC
import logging
import os
Expand Down Expand Up @@ -111,6 +111,18 @@ def remote_prefilter(filters: dict) -> callable:
"""
raise NotImplementedError

@staticmethod
def remote_prefilter_counter(filters: dict) -> Tuple[Callable[[dict], int], bool]:
"""Applies filtering logic that must be done before retrieving remote entities count

Args:
filters (dict): filters to apply

Returns:
callable: A function for retrieving remote entities count with the applied prefilters
"""
raise NotImplementedError

@classmethod
def get(
cls: Type[EntityType],
Expand Down Expand Up @@ -229,3 +241,14 @@ def display_dict(self) -> dict:
dict: the display dictionary
"""
raise NotImplementedError

@classmethod
def get_count(cls: Type[EntityType], filters: dict = {}) -> int:
"""Returns the count of items in the entity
Returns:
int: count of items
"""
logging.info(f"Retrieving the count of {cls.get_type()} entities")
comms_fn, is_owner = cls.remote_prefilter_counter(filters=filters)
count = comms_fn(filters=filters, is_owner=is_owner)
return count
58 changes: 49 additions & 9 deletions cli/medperf/web_ui/benchmarks/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,65 @@
def benchmarks_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):

if ordering == "created_at_asc":
order = "created_at"
elif ordering == "name_asc":
order = "name"
elif ordering == "name_desc":
order = "-name"
else:
order = "-created_at"

filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id

benchmarks = Benchmark.all(
filters=filters,
)
total_count = Benchmark.get_count(filters=filters)

# Pagination
offset = (page - 1) * page_size
filters["limit"] = page_size
filters["offset"] = offset

# Ordering
filters["ordering"] = order

benchmarks = Benchmark.all(filters=filters)

my_benchmarks = [b for b in benchmarks if b.owner == my_user_id]
other_benchmarks = [b for b in benchmarks if b.owner != my_user_id]
benchmarks = my_benchmarks + other_benchmarks

total_pages = (total_count + page_size - 1) // page_size

start_index = 0
end_index = 0
if total_count != 0:
start_index = offset + 1
end_index = min(offset + len(benchmarks), total_count)

benchmarks = sorted(benchmarks, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
mine_benchmarks = [d for d in benchmarks if d.owner == my_user_id]
other_benchmarks = [d for d in benchmarks if d.owner != my_user_id]
benchmarks = mine_benchmarks + other_benchmarks
return templates.TemplateResponse(
"benchmark/benchmarks.html",
{"request": request, "benchmarks": benchmarks, "mine_only": mine_only},
{
"request": request,
"benchmarks": benchmarks,
"mine_only": mine_only,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
"ordering": ordering,
"total_count": total_count,
"start_index": start_index,
"end_index": end_index,
},
)


Expand Down
52 changes: 46 additions & 6 deletions cli/medperf/web_ui/containers/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,64 @@
def containers_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):
if ordering == "created_at_asc":
order = "created_at"
elif ordering == "name_asc":
order = "name"
elif ordering == "name_desc":
order = "-name"
else:
order = "-created_at"

filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id

containers = Cube.all(
filters=filters,
)
containers = sorted(containers, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
total_count = Cube.get_count(filters=filters)

# Pagination
offset = (page - 1) * page_size
filters["limit"] = page_size
filters["offset"] = offset

# Ordering
filters["ordering"] = order

containers = Cube.all(filters=filters)

my_containers = [c for c in containers if c.owner == my_user_id]
other_containers = [c for c in containers if c.owner != my_user_id]
containers = my_containers + other_containers

total_pages = (total_count + page_size - 1) // page_size

start_index = 0
end_index = 0
if total_count != 0:
start_index = offset + 1
end_index = min(offset + len(containers), total_count)

return templates.TemplateResponse(
"container/containers.html",
{"request": request, "containers": containers, "mine_only": mine_only},
{
"request": request,
"containers": containers,
"mine_only": mine_only,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
"ordering": ordering,
"total_count": total_count,
"start_index": start_index,
"end_index": end_index,
},
)


Expand Down
56 changes: 48 additions & 8 deletions cli/medperf/web_ui/datasets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,64 @@
def datasets_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):
if ordering == "created_at_asc":
order = "created_at"
elif ordering == "name_asc":
order = "name"
elif ordering == "name_desc":
order = "-name"
else:
order = "-created_at"

filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id
datasets = Dataset.all(
filters=filters,
)

datasets = sorted(datasets, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
mine_datasets = [d for d in datasets if d.owner == my_user_id]
total_count = Dataset.get_count(filters=filters)

# Pagination
offset = (page - 1) * page_size
filters["limit"] = page_size
filters["offset"] = offset

# Ordering
filters["ordering"] = order

datasets = Dataset.all(filters=filters)

my_datasets = [d for d in datasets if d.owner == my_user_id]
other_datasets = [d for d in datasets if d.owner != my_user_id]
datasets = mine_datasets + other_datasets
datasets = my_datasets + other_datasets

total_pages = (total_count + page_size - 1) // page_size

start_index = 0
end_index = 0
if total_count != 0:
start_index = offset + 1
end_index = min(offset + len(datasets), total_count)

return templates.TemplateResponse(
"dataset/datasets.html",
{"request": request, "datasets": datasets, "mine_only": mine_only},
{
"request": request,
"datasets": datasets,
"mine_only": mine_only,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
"ordering": ordering,
"total_count": total_count,
"start_index": start_index,
"end_index": end_index,
},
)


Expand Down
Loading
Loading