Skip to content
18 changes: 17 additions & 1 deletion cli/medperf/commands/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,29 @@ def list(
False, "--unregistered", help="Get unregistered benchmarks"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"),
name: str = typer.Option(None, "--name", help="Filter by name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"),
is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"),
data_prep: int = typer.Option(None, "-d", "--data-preparation-mlcube", help="Filter by Data Preparation MLCube"),
):
"""List benchmarks"""
filters = {
"name": name,
"owner": owner,
"state": state,
"is_valid": is_valid,
"is_active": is_active,
"data_preparation_mlcube": data_prep
}

EntityList.run(
Benchmark,
fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"],
fields=["UID", "Name", "Description", "Data Preparation MLCube", "State", "Approval Status", "Registered"],
unregistered=unregistered,
mine_only=mine,
**filters,
)


Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def list(
mlcube: int = typer.Option(
None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube"
),
name: str = typer.Option(None, "--name", help="Filter by name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"),
):
"""List datasets"""
EntityList.run(
Expand All @@ -32,6 +36,10 @@ def list(
unregistered=unregistered,
mine_only=mine,
mlcube=mlcube,
name=name,
owner=owner,
state=state,
is_valid=is_valid,
)


Expand Down
3 changes: 2 additions & 1 deletion cli/medperf/commands/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def run(
Args:
unregistered (bool, optional): Display only local unregistered results. Defaults to False.
mine_only (bool, optional): Display all registered current-user results. Defaults to False.
kwargs (dict): Additional parameters for filtering entity lists.
kwargs (dict): Additional parameters for filtering entity lists. Keys with None will be filtered out.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
entity_list = EntityList(
entity_class, fields, unregistered, mine_only, **kwargs
)
Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/mlcube/mlcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ def list(
False, "--unregistered", help="Get unregistered mlcubes"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"),
name: str = typer.Option(None, "--name", "-n", help="Filter out by MLCube Name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner ID"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"),
):
"""List mlcubes"""
EntityList.run(
Cube,
fields=["UID", "Name", "State", "Registered"],
unregistered=unregistered,
mine_only=mine,
name=name,
owner=owner,
state=state,
is_active=is_active,
)


Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def list(
benchmark: int = typer.Option(
None, "--benchmark", "-b", help="Get results for a given benchmark"
),
model: int = typer.Option(
None, "--owner", "-o", help="Get results for a given model"
),
dataset: int = typer.Option(
None, "--dataset", "-d", help="Get reuslts for a given dataset"
),
):
"""List results"""
EntityList.run(
Expand All @@ -77,6 +83,8 @@ def list(
unregistered=unregistered,
mine_only=mine,
benchmark=benchmark,
model=model,
dataset=dataset,
)


Expand Down
70 changes: 43 additions & 27 deletions cli/medperf/comms/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __get_list(
page_size=config.default_page_size,
offset=0,
binary_reduction=False,
filters={},
):
"""Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained.
If num_elements is None, then iterates until all elements have been retrieved.
Expand All @@ -104,7 +105,9 @@ def __get_list(
num_elements = float("inf")

while len(el_list) < num_elements:
paginated_url = f"{url}?limit={page_size}&offset={offset}"
filters.update({"limit": page_size, "offset": offset})
query_str = "&".join([f"{k}={v}" for k, v in filters.items()])
paginated_url = f"{url}?{query_str}"
res = self.__auth_get(paginated_url)
if res.status_code != 200:
if not binary_reduction:
Expand Down Expand Up @@ -152,13 +155,13 @@ def get_current_user(self):
res = self.__auth_get(f"{self.server_url}/me/")
return res.json()

def get_benchmarks(self) -> List[dict]:
def get_benchmarks(self, filters={}) -> List[dict]:
"""Retrieves all benchmarks in the platform.

Returns:
List[dict]: all benchmarks information.
"""
bmks = self.__get_list(f"{self.server_url}/benchmarks/")
bmks = self.__get_list(f"{self.server_url}/benchmarks/", filters=filters)
return bmks

def get_benchmark(self, benchmark_uid: int) -> dict:
Expand All @@ -179,7 +182,7 @@ def get_benchmark(self, benchmark_uid: int) -> dict:
)
return res.json()

def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]:
def get_benchmark_model_associations(self, benchmark_uid: int, filters={}) -> List[int]:
"""Retrieves all the model associations of a benchmark.

Args:
Expand All @@ -188,25 +191,28 @@ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]:
Returns:
list[int]: List of benchmark model associations
"""
assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models")
assocs = self.__get_list(
f"{self.server_url}/benchmarks/{benchmark_uid}/models",
filters=filters,
)
return filter_latest_associations(assocs, "model_mlcube")

def get_user_benchmarks(self) -> List[dict]:
def get_user_benchmarks(self, filters={}) -> List[dict]:
"""Retrieves all benchmarks created by the user

Returns:
List[dict]: Benchmarks data
"""
bmks = self.__get_list(f"{self.server_url}/me/benchmarks/")
bmks = self.__get_list(f"{self.server_url}/me/benchmarks/", filters=filters)
return bmks

def get_cubes(self) -> List[dict]:
def get_cubes(self, filters={}) -> List[dict]:
"""Retrieves all MLCubes in the platform

Returns:
List[dict]: List containing the data of all MLCubes
"""
cubes = self.__get_list(f"{self.server_url}/mlcubes/")
cubes = self.__get_list(f"{self.server_url}/mlcubes/", filters=filters)
return cubes

def get_cube_metadata(self, cube_uid: int) -> dict:
Expand All @@ -227,13 +233,13 @@ def get_cube_metadata(self, cube_uid: int) -> dict:
)
return res.json()

def get_user_cubes(self) -> List[dict]:
def get_user_cubes(self, filters={}) -> List[dict]:
"""Retrieves metadata from all cubes registered by the user

Returns:
List[dict]: List of dictionaries containing the mlcubes registration information
"""
cubes = self.__get_list(f"{self.server_url}/me/mlcubes/")
cubes = self.__get_list(f"{self.server_url}/me/mlcubes/", filters=filters)
return cubes

def upload_benchmark(self, benchmark_dict: dict) -> int:
Expand Down Expand Up @@ -268,13 +274,13 @@ def upload_mlcube(self, mlcube_body: dict) -> int:
raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}")
return res.json()

def get_datasets(self) -> List[dict]:
def get_datasets(self, filters={}) -> List[dict]:
"""Retrieves all datasets in the platform

Returns:
List[dict]: List of data from all datasets
"""
dsets = self.__get_list(f"{self.server_url}/datasets/")
dsets = self.__get_list(f"{self.server_url}/datasets/", filters=filters)
return dsets

def get_dataset(self, dset_uid: int) -> dict:
Expand All @@ -295,13 +301,13 @@ def get_dataset(self, dset_uid: int) -> dict:
)
return res.json()

def get_user_datasets(self) -> dict:
def get_user_datasets(self, filters={}) -> dict:
"""Retrieves all datasets registered by the user

Returns:
dict: dictionary with the contents of each dataset registration query
"""
dsets = self.__get_list(f"{self.server_url}/me/datasets/")
dsets = self.__get_list(f"{self.server_url}/me/datasets/", filters=filters)
return dsets

def upload_dataset(self, reg_dict: dict) -> int:
Expand All @@ -320,13 +326,13 @@ def upload_dataset(self, reg_dict: dict) -> int:
raise CommunicationRequestError(f"Could not upload the dataset: {details}")
return res.json()

def get_results(self) -> List[dict]:
def get_results(self, filters={}) -> List[dict]:
"""Retrieves all results

Returns:
List[dict]: List of results
"""
res = self.__get_list(f"{self.server_url}/results")
res = self.__get_list(f"{self.server_url}/results", filters=filters)
if res.status_code != 200:
log_response_error(res)
details = format_errors_dict(res.json())
Expand All @@ -351,16 +357,16 @@ def get_result(self, result_uid: int) -> dict:
)
return res.json()

def get_user_results(self) -> dict:
def get_user_results(self, filters={}) -> dict:
"""Retrieves all results registered by the user

Returns:
dict: dictionary with the contents of each result registration query
"""
results = self.__get_list(f"{self.server_url}/me/results/")
results = self.__get_list(f"{self.server_url}/me/results/", filters=filters)
return results

def get_benchmark_results(self, benchmark_id: int) -> dict:
def get_benchmark_results(self, benchmark_id: int, filters={}) -> dict:
"""Retrieves all results for a given benchmark

Args:
Expand All @@ -370,7 +376,8 @@ def get_benchmark_results(self, benchmark_id: int) -> dict:
dict: dictionary with the contents of each result in the specified benchmark
"""
results = self.__get_list(
f"{self.server_url}/benchmarks/{benchmark_id}/results"
f"{self.server_url}/benchmarks/{benchmark_id}/results",
filters=filters,
)
return results

Expand Down Expand Up @@ -472,22 +479,28 @@ def set_mlcube_association_approval(
f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}"
)

def get_datasets_associations(self) -> List[dict]:
def get_datasets_associations(self, filters={}) -> List[dict]:
"""Get all dataset associations related to the current user

Returns:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/")
assocs = self.__get_list(
f"{self.server_url}/me/datasets/associations/",
filters=filters,
)
return filter_latest_associations(assocs, "dataset")

def get_cubes_associations(self) -> List[dict]:
def get_cubes_associations(self, filters={}) -> List[dict]:
"""Get all cube associations related to the current user

Returns:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/")
assocs = self.__get_list(
f"{self.server_url}/me/mlcubes/associations/",
filters=filters,
)
return filter_latest_associations(assocs, "model_mlcube")

def set_mlcube_association_priority(
Expand Down Expand Up @@ -519,7 +532,7 @@ def update_dataset(self, dataset_id: int, data: dict):
raise CommunicationRequestError(f"Could not update dataset: {details}")
return res.json()

def get_mlcube_datasets(self, mlcube_id: int) -> dict:
def get_mlcube_datasets(self, mlcube_id: int, filters={}) -> dict:
"""Retrieves all datasets that have the specified mlcube as the prep mlcube

Args:
Expand All @@ -529,7 +542,10 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict:
dict: dictionary with the contents of each dataset
"""

datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/")
datasets = self.__get_list(
f"{self.server_url}/mlcubes/{mlcube_id}/datasets/",
filters=filters,
)
return datasets

def get_user(self, user_id: int) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/entities/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def all(
@classmethod
def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]:
comms_fn = cls.remote_prefilter(filters)
entity_meta = comms_fn()
entity_meta = comms_fn(filters)
entities = [cls(**meta) for meta in entity_meta]
return entities

Expand Down
1 change: 1 addition & 0 deletions cli/medperf/entities/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def remote_prefilter(filters: dict) -> callable:
comms_fn = config.comms.get_results
if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]:
comms_fn = config.comms.get_user_results
del filters["owner"]
if "benchmark" in filters and filters["benchmark"] is not None:
bmk = filters["benchmark"]

Expand Down
Loading