Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit ccd748c

Browse files
committed
Refactor list commands to use share common parent class
1 parent da6e9fa commit ccd748c

20 files changed

+169
-256
lines changed

paperspace/cli/deployments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def get_deployments_list(api_key=None, **filters):
105105
del_if_value_is_none(filters)
106106
deployments_api = client.API(config.CONFIG_HOST, api_key=api_key)
107107
command = deployments_commands.ListDeploymentsCommand(api=deployments_api)
108-
command.execute(filters)
108+
command.execute(filters=filters)
109109

110110

111111
@deployments.command("update", help="Update deployment properties")

paperspace/cli/experiments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def stop_experiment(experiment_id, api_key):
294294
def list_experiments(project_ids, api_key):
295295
experiments_api = client.API(config.CONFIG_EXPERIMENTS_HOST, api_key=api_key)
296296
command = experiments_commands.ListExperimentsCommand(api=experiments_api)
297-
command.execute(project_ids)
297+
command.execute(project_ids=project_ids)
298298

299299

300300
@experiments.command("details", help="Show detail of an experiment")

paperspace/cli/jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def list_jobs(api_key, **filters):
6262
common.del_if_value_is_none(filters)
6363
jobs_api = client.API(config.CONFIG_HOST, api_key=api_key)
6464
command = jobs_commands.ListJobsCommand(api=jobs_api)
65-
command.execute(filters)
65+
command.execute(filters=filters)
6666

6767

6868
@jobs_group.command("create", help="Create job")

paperspace/cli/machines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ def list_machines(api_key, params, **kwargs):
360360

361361
machines_api = client.API(config.CONFIG_HOST, api_key=api_key)
362362
command = machines_commands.ListMachinesCommand(api=machines_api)
363-
command.execute(params or kwargs)
363+
filters = params or kwargs
364+
command.execute(filters=filters)
364365

365366

366367
restart_machine_help = "Restart an individual machine. If the machine is already restarting, this action will " \

paperspace/cli/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ def list_models(api_key, **filters):
2727
common.del_if_value_is_none(filters)
2828
models_api = client.API(config.CONFIG_HOST, api_key=api_key)
2929
command = models_commands.ListModelsCommand(api=models_api)
30-
command.execute(filters)
30+
command.execute(filters=filters)

paperspace/commands/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
from paperspace import logger
2-
3-
4-
class CommandBase(object):
5-
def __init__(self, api=None, logger_=logger):
6-
self.api = api
7-
self.logger = logger_

paperspace/commands/common.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pydoc
2+
3+
import terminaltables
4+
5+
from paperspace import logger
6+
from paperspace.utils import get_terminal_lines
7+
8+
9+
class CommandBase(object):
10+
def __init__(self, api=None, logger_=logger):
11+
self.api = api
12+
self.logger = logger_
13+
14+
15+
class ListCommand(CommandBase):
16+
@property
17+
def request_url(self):
18+
raise NotImplementedError()
19+
20+
def execute(self, **kwargs):
21+
response = self._get_response(kwargs)
22+
23+
try:
24+
if not response.ok:
25+
self.logger.log_error_response(response.json())
26+
return
27+
28+
objects = self._get_objects(response, kwargs)
29+
except (ValueError, KeyError) as e:
30+
self.logger.error("Error while parsing response data: {}".format(e))
31+
else:
32+
self._log_objects_list(objects)
33+
34+
def _log_objects_list(self, objects):
35+
if not objects:
36+
self.logger.warning("No data found")
37+
return
38+
39+
table_data = self._get_table_data(objects)
40+
table_str = self._make_table(table_data)
41+
if len(table_str.splitlines()) > get_terminal_lines():
42+
pydoc.pager(table_str)
43+
else:
44+
self.logger.log(table_str)
45+
46+
def _get_objects(self, response, kwargs):
47+
data = response.json()
48+
return data
49+
50+
def _get_response(self, kwargs):
51+
json_ = self._get_request_json(kwargs)
52+
params = self._get_request_params(kwargs)
53+
response = self.api.get(self.request_url, json=json_, params=params)
54+
return response
55+
56+
def _get_table_data(self, objects):
57+
raise NotImplementedError()
58+
59+
@staticmethod
60+
def _make_table(table_data):
61+
ascii_table = terminaltables.AsciiTable(table_data)
62+
table_string = ascii_table.table
63+
return table_string
64+
65+
def _get_request_json(self, kwargs):
66+
return None
67+
68+
def _get_request_params(self, kwargs):
69+
return None

paperspace/commands/deployments.py

Lines changed: 15 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import pydoc
2-
31
import terminaltables
42

5-
from paperspace import config, version, client, logger
6-
from paperspace.commands import CommandBase
7-
from paperspace.utils import get_terminal_lines
3+
from paperspace import config, version, client
4+
from paperspace.commands import common
5+
from paperspace.commands.common import CommandBase
86

97
default_headers = {"X-API-Key": config.PAPERSPACE_API_KEY,
108
"ps_client_name": "paperspace-python",
@@ -39,51 +37,25 @@ def execute(self, kwargs):
3937
"Unknown error during deployment")
4038

4139

42-
class ListDeploymentsCommand(_DeploymentCommandBase):
43-
def execute(self, filters=None):
44-
json_ = self._get_request_json(filters)
45-
response = self.api.get("/deployments/getDeploymentList/", json=json_)
40+
class ListDeploymentsCommand(common.ListCommand):
41+
@property
42+
def request_url(self):
43+
return "/deployments/getDeploymentList/"
4644

47-
try:
48-
data = response.json()
49-
if not response.ok:
50-
self.logger.log_error_response(data)
51-
return
52-
deployments = self._get_deployments_list(response)
53-
except (ValueError, KeyError) as e:
54-
self.logger.error("Error while parsing response data: {}".format(e))
55-
else:
56-
self._log_deployments_list(deployments)
57-
58-
@staticmethod
59-
def _get_request_json(filters):
45+
def _get_request_json(self, kwargs):
46+
filters = kwargs.get("filters")
6047
if not filters:
6148
return None
6249

6350
json_ = {"filter": {"where": {"and": [filters]}}}
6451
return json_
6552

66-
@staticmethod
67-
def _get_deployments_list(response):
68-
if not response.ok:
69-
raise ValueError("Unknown error")
70-
71-
data = response.json()["deploymentList"]
72-
logger.debug(data)
73-
return data
53+
def _get_objects(self, response, kwargs):
54+
data = super(ListDeploymentsCommand, self)._get_objects(response, kwargs)
55+
objects = data["deploymentList"]
56+
return objects
7457

75-
def _log_deployments_list(self, deployments):
76-
if not deployments:
77-
self.logger.warning("No deployments found")
78-
else:
79-
table_str = self._make_deployments_list_table(deployments)
80-
if len(table_str.splitlines()) > get_terminal_lines():
81-
pydoc.pager(table_str)
82-
else:
83-
self.logger.log(table_str)
84-
85-
@staticmethod
86-
def _make_deployments_list_table(deployments):
58+
def _get_table_data(self, deployments):
8759
data = [("Name", "ID", "Endpoint", "Api Type", "Deployment Type")]
8860
for deployment in deployments:
8961
name = deployment.get("name")
@@ -93,9 +65,7 @@ def _make_deployments_list_table(deployments):
9365
deployment_type = deployment.get("deploymentType")
9466
data.append((name, id_, endpoint, api_type, deployment_type))
9567

96-
ascii_table = terminaltables.AsciiTable(data)
97-
table_string = ascii_table.table
98-
return table_string
68+
return data
9969

10070

10171
class UpdateDeploymentCommand(_DeploymentCommandBase):

paperspace/commands/experiments.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import terminaltables
44

55
from paperspace import logger, constants, client, config
6-
from paperspace.commands import CommandBase
76
from paperspace.workspace import S3WorkspaceHandler
87
from paperspace.logger import log_response
98
from paperspace.utils import get_terminal_lines
9+
from . import common
1010

1111
experiments_api = client.API(config.CONFIG_EXPERIMENTS_HOST, headers=client.default_headers)
1212

1313

14-
class ExperimentCommand(CommandBase):
14+
class ExperimentCommand(common.CommandBase):
1515
def __init__(self, workspace_handler=None, **kwargs):
1616
super(ExperimentCommand, self).__init__(**kwargs)
1717
self._workspace_handler = workspace_handler or S3WorkspaceHandler(experiments_api=self.api, logger=self.logger)
@@ -68,52 +68,36 @@ def stop_experiment(experiment_id, api=experiments_api):
6868
log_response(response, "Experiment stopped", "Unknown error while stopping the experiment")
6969

7070

71-
class ListExperimentsCommand(object):
72-
def __init__(self, api=experiments_api, logger_=logger):
73-
self.api = api
74-
self.logger = logger_
71+
class ListExperimentsCommand(common.ListCommand):
72+
@property
73+
def request_url(self):
74+
return "/experiments/"
7575

76-
def execute(self, project_ids=None):
77-
project_ids = project_ids or []
78-
params = self._get_query_params(project_ids)
79-
response = self.api.get("/experiments/", params=params)
80-
81-
try:
82-
data = response.json()
83-
if not response.ok:
84-
self.logger.log_error_response(data)
85-
return
86-
87-
experiments = self._get_experiments_list(data, bool(project_ids))
88-
except (ValueError, KeyError) as e:
89-
self.logger.error("Error while parsing response data: {}".format(e))
90-
else:
91-
self._log_experiments_list(experiments)
92-
93-
@staticmethod
94-
def _get_query_params(project_ids):
76+
def _get_request_params(self, kwargs):
9577
params = {"limit": -1} # so the API sends back full list without pagination
96-
for i, experiment_id in enumerate(project_ids):
97-
key = "projectHandle[{}]".format(i)
98-
params[key] = experiment_id
78+
79+
project_ids = kwargs.get("project_ids")
80+
if project_ids:
81+
for i, experiment_id in enumerate(project_ids):
82+
key = "projectHandle[{}]".format(i)
83+
params[key] = experiment_id
9984

10085
return params
10186

102-
@staticmethod
103-
def _make_experiments_list_table(experiments):
87+
def _get_table_data(self, experiments):
10488
data = [("Name", "ID", "Status")]
10589
for experiment in experiments:
10690
name = experiment["templateHistory"]["params"].get("name")
10791
handle = experiment["handle"]
10892
status = constants.ExperimentState.get_state_str(experiment["state"])
10993
data.append((name, handle, status))
11094

111-
ascii_table = terminaltables.AsciiTable(data)
112-
table_string = ascii_table.table
113-
return table_string
95+
return data
96+
97+
def _get_objects(self, response, kwargs):
98+
data = super(ListExperimentsCommand, self)._get_objects(response, kwargs)
11499

115-
@staticmethod
116-
def _get_experiments_list(data, filtered=False):
100+
filtered = bool(kwargs.get("project_ids"))
117101
if not filtered: # If filtering by project ID response data has different format...
118102
return data["data"]
119103

@@ -123,16 +107,6 @@ def _get_experiments_list(data, filtered=False):
123107
experiments.append(experiment)
124108
return experiments
125109

126-
def _log_experiments_list(self, experiments):
127-
if not experiments:
128-
self.logger.warning("No experiments found")
129-
else:
130-
table_str = self._make_experiments_list_table(experiments)
131-
if len(table_str.splitlines()) > get_terminal_lines():
132-
pydoc.pager(table_str)
133-
else:
134-
self.logger.log(table_str)
135-
136110

137111
def _make_details_table(experiment):
138112
if experiment["experimentTypeId"] == constants.ExperimentType.SINGLE_NODE:

paperspace/commands/jobs.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from click import style
55

66
from paperspace import config, client
7-
from paperspace.commands import CommandBase
7+
from paperspace.commands.common import CommandBase
88
from paperspace.utils import get_terminal_lines
99
from paperspace.workspace import S3WorkspaceHandler
10+
from . import common
1011

1112

1213
class JobsCommandBase(CommandBase):
@@ -45,33 +46,17 @@ def execute(self, job_id):
4546
"Unknown error while stopping job")
4647

4748

48-
class ListJobsCommand(JobsCommandBase):
49-
def execute(self, filters=None):
50-
json_ = filters or None
51-
response = self.api.get("/jobs/getJobs/", json=json_)
52-
53-
try:
54-
data = response.json()
55-
if not response.ok:
56-
self.logger.log_error_response(data)
57-
return
58-
except (ValueError, KeyError) as e:
59-
self.logger.error("Error while parsing response data: {}".format(e))
60-
else:
61-
self._log_jobs_list(data)
49+
class ListJobsCommand(common.ListCommand):
50+
@property
51+
def request_url(self):
52+
return "/jobs/getJobs/"
6253

63-
def _log_jobs_list(self, data):
64-
if not data:
65-
self.logger.warning("No jobs found")
66-
else:
67-
table_str = self._make_table(data)
68-
if len(table_str.splitlines()) > get_terminal_lines():
69-
pydoc.pager(table_str)
70-
else:
71-
self.logger.log(table_str)
54+
def _get_request_json(self, kwargs):
55+
filters = kwargs.get("filters")
56+
json_ = filters or None
57+
return json_
7258

73-
@staticmethod
74-
def _make_table(jobs):
59+
def _get_table_data(self, jobs):
7560
data = [("ID", "Name", "Project", "Cluster", "Machine Type", "Created")]
7661
for job in jobs:
7762
id_ = job.get("id")
@@ -82,9 +67,7 @@ def _make_table(jobs):
8267
created = job.get("dtCreated")
8368
data.append((id_, name, project, cluster, machine_type, created))
8469

85-
ascii_table = terminaltables.AsciiTable(data)
86-
table_string = ascii_table.table
87-
return table_string
70+
return data
8871

8972

9073
class JobLogsCommand(CommandBase):

0 commit comments

Comments
 (0)