diff --git a/README.md b/README.md index cb08ff7..ff59e27 100644 --- a/README.md +++ b/README.md @@ -329,6 +329,7 @@ my_custom_requests_session = Session() aerie_host = AerieHost( GRAPHQL_URL, GATEWAY_URL, + FILESTORE_URL, session=my_custom_requests_session ) aerie_host.authenticate(...) diff --git a/src/aerie_cli/aerie_client.py b/src/aerie_cli/aerie_client.py index 6c01fd7..811c8c7 100644 --- a/src/aerie_cli/aerie_client.py +++ b/src/aerie_cli/aerie_client.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Dict from typing import List +from typing import Optional from typing import Union from copy import deepcopy @@ -43,7 +44,11 @@ def __init__(self, aerie_host: AerieHost): """ self.aerie_host = aerie_host - def get_activity_plan_by_id(self, plan_id: int, full_args: str = None) -> ActivityPlanRead: + def get_activity_plan_by_id( + self, + plan_id: int, + full_args: str = None, + ) -> ActivityPlanRead: """Download activity plan from Aerie Args: @@ -160,11 +165,10 @@ def get_plan_id_by_sim_id(self, simulation_dataset_id: int) -> int: } """ resp = self.aerie_host.post_to_graphql( - get_plan_id_query, - simulation_dataset_id=simulation_dataset_id + get_plan_id_query, simulation_dataset_id=simulation_dataset_id ) - return resp['simulation']['plan']['id'] - + return resp["simulation"]["plan"]["id"] + def get_tag_id_by_name(self, tag_name: str): get_tags_by_name_query = """ query GetTagByName($name: String) { @@ -174,7 +178,7 @@ def get_tag_id_by_name(self, tag_name: str): } """ - #make default color of tag white + # make default color of tag white create_new_tag = """ mutation CreateNewTag($name: String, $color: String = "#FFFFFF") { insert_tags_one(object: {name: $name, color: $color}) { @@ -184,17 +188,17 @@ def get_tag_id_by_name(self, tag_name: str): """ resp = self.aerie_host.post_to_graphql( - get_tags_by_name_query, - name=tag_name + get_tags_by_name_query, + name=tag_name, ) - #if a tag with the specified name exists then returns the ID, else creates a new tag with this name - if len(resp) > 0: + # if a tag with the specified name exists then returns the ID, else creates a new tag with this name + if len(resp) > 0: return resp[0]["id"] - else: + else: new_tag_resp = self.aerie_host.post_to_graphql( - create_new_tag, - name=tag_name + create_new_tag, + name=tag_name, ) return new_tag_resp["id"] @@ -209,15 +213,15 @@ def add_plan_tag(self, plan_id: int, tag_name: str): } } """ - - #add tag to plan + + # add tag to plan resp = self.aerie_host.post_to_graphql( - add_tag_to_plan, - plan_id=plan_id, - tag_id=self.get_tag_id_by_name(tag_name) + add_tag_to_plan, + plan_id=plan_id, + tag_id=self.get_tag_id_by_name(tag_name), ) - return resp['returning'][0] + return resp["returning"][0] def create_activity_plan( self, model_id: int, plan_to_create: ActivityPlanCreate @@ -239,10 +243,10 @@ def create_activity_plan( plan_id = plan_resp["id"] plan_revision = plan_resp["revision"] - #add plan tags if exists from plan_to_create + # add plan tags if exists from plan_to_create for tag in plan_to_create.tags: self.add_plan_tag(plan_id, tag["tag"]["name"]) - + # This loop exists to make sure all anchor IDs are updated as necessary # Deep copy activities so we can augment and pop from the list @@ -292,7 +296,7 @@ def create_activity_plan( update_simulation_mutation, plan_id=plan_id, simulation_start_time=simulation_start_time, - simulation_end_time=simulation_end_time + simulation_end_time=simulation_end_time, ) return plan_id @@ -308,7 +312,7 @@ def create_activity(self, activity_to_create: Activity, plan_id: int) -> int: """ resp = self.aerie_host.post_to_graphql( insert_activity_mutation, - activity=api_activity_create.to_dict() + activity=api_activity_create.to_dict(), ) activity_id = resp["id"] @@ -318,7 +322,7 @@ def update_activity( self, activity_id: int, activity_to_update: Activity, - plan_id: int + plan_id: int, ) -> int: activity_dict: Dict = activity_to_update.to_api_update().to_dict() update_activity_mutation = """ @@ -338,7 +342,7 @@ def update_activity( ) return resp["id"] - def get_all_activity_presets(self, m_id:int) -> List: + def get_all_activity_presets(self, m_id: int) -> List: get_all_presets_query = """ query ($model_id: Int!) { activity_presets (where: {model_id:{_eq:$model_id}}){ @@ -351,10 +355,7 @@ def get_all_activity_presets(self, m_id:int) -> List: } """ - resp = self.aerie_host.post_to_graphql( - get_all_presets_query, - model_id=m_id - ) + resp = self.aerie_host.post_to_graphql(get_all_presets_query, model_id=m_id) return resp def upload_activity_presets(self, upload_obj): @@ -375,8 +376,7 @@ def upload_activity_presets(self, upload_obj): }""" resp = self.aerie_host.post_to_graphql( - upload_activity_presets_query, - object = upload_obj + upload_activity_presets_query, object=upload_obj ) return resp["returning"] @@ -410,19 +410,23 @@ def exec_sim_query(): return sim_dataset_id def get_resource_timelines(self, plan_id: int): - samples = self.get_resource_samples(self.get_simulation_dataset_ids_by_plan_id(plan_id)[0]) + samples = self.get_resource_samples( + self.get_simulation_dataset_ids_by_plan_id(plan_id)[0] + ) api_resource_timeline = ApiResourceSampleResults.from_dict(samples) return api_resource_timeline - def get_resource_samples(self, simulation_dataset_id: int, state_names: List=None): + def get_resource_samples( + self, simulation_dataset_id: int, state_names: List = None + ): """Pull resource samples from a simulation dataset, optionally filtering for specific states Each resource's values are returned in a list of points {x: