33import terminaltables
44
55from paperspace import logger , constants , client , config
6- from paperspace .commands import CommandBase
76from paperspace .workspace import S3WorkspaceHandler
87from paperspace .logger import log_response
98from paperspace .utils import get_terminal_lines
9+ from . import common
1010
1111experiments_api = client .API (config .CONFIG_EXPERIMENTS_HOST , headers = client .default_headers )
1212
1313
14- class ExperimentCommand (CommandBase ):
14+ class ExperimentCommand (common . CommandBase ):
1515 def __init__ (self , workspace_handler = None , ** kwargs ):
1616 super (ExperimentCommand , self ).__init__ (** kwargs )
1717 self ._workspace_handler = workspace_handler or S3WorkspaceHandler (experiments_api = self .api , logger = self .logger )
@@ -68,52 +68,36 @@ def stop_experiment(experiment_id, api=experiments_api):
6868 log_response (response , "Experiment stopped" , "Unknown error while stopping the experiment" )
6969
7070
71- class ListExperimentsCommand (object ):
72- def __init__ ( self , api = experiments_api , logger_ = logger ):
73- self . api = api
74- self . logger = logger_
71+ class ListExperimentsCommand (common . ListCommand ):
72+ @ property
73+ def request_url ( self ):
74+ return "/experiments/"
7575
76- def execute (self , project_ids = None ):
77- project_ids = project_ids or []
78- params = self ._get_query_params (project_ids )
79- response = self .api .get ("/experiments/" , params = params )
80-
81- try :
82- data = response .json ()
83- if not response .ok :
84- self .logger .log_error_response (data )
85- return
86-
87- experiments = self ._get_experiments_list (data , bool (project_ids ))
88- except (ValueError , KeyError ) as e :
89- self .logger .error ("Error while parsing response data: {}" .format (e ))
90- else :
91- self ._log_experiments_list (experiments )
92-
93- @staticmethod
94- def _get_query_params (project_ids ):
76+ def _get_request_params (self , kwargs ):
9577 params = {"limit" : - 1 } # so the API sends back full list without pagination
96- for i , experiment_id in enumerate (project_ids ):
97- key = "projectHandle[{}]" .format (i )
98- params [key ] = experiment_id
78+
79+ project_ids = kwargs .get ("project_ids" )
80+ if project_ids :
81+ for i , experiment_id in enumerate (project_ids ):
82+ key = "projectHandle[{}]" .format (i )
83+ params [key ] = experiment_id
9984
10085 return params
10186
102- @staticmethod
103- def _make_experiments_list_table (experiments ):
87+ def _get_table_data (self , experiments ):
10488 data = [("Name" , "ID" , "Status" )]
10589 for experiment in experiments :
10690 name = experiment ["templateHistory" ]["params" ].get ("name" )
10791 handle = experiment ["handle" ]
10892 status = constants .ExperimentState .get_state_str (experiment ["state" ])
10993 data .append ((name , handle , status ))
11094
111- ascii_table = terminaltables .AsciiTable (data )
112- table_string = ascii_table .table
113- return table_string
95+ return data
96+
97+ def _get_objects (self , response , kwargs ):
98+ data = super (ListExperimentsCommand , self )._get_objects (response , kwargs )
11499
115- @staticmethod
116- def _get_experiments_list (data , filtered = False ):
100+ filtered = bool (kwargs .get ("project_ids" ))
117101 if not filtered : # If filtering by project ID response data has different format...
118102 return data ["data" ]
119103
@@ -123,16 +107,6 @@ def _get_experiments_list(data, filtered=False):
123107 experiments .append (experiment )
124108 return experiments
125109
126- def _log_experiments_list (self , experiments ):
127- if not experiments :
128- self .logger .warning ("No experiments found" )
129- else :
130- table_str = self ._make_experiments_list_table (experiments )
131- if len (table_str .splitlines ()) > get_terminal_lines ():
132- pydoc .pager (table_str )
133- else :
134- self .logger .log (table_str )
135-
136110
137111def _make_details_table (experiment ):
138112 if experiment ["experimentTypeId" ] == constants .ExperimentType .SINGLE_NODE :
0 commit comments