-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.py
More file actions
196 lines (157 loc) · 6.26 KB
/
cli.py
File metadata and controls
196 lines (157 loc) · 6.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from db import Database
import matplotlib.pyplot as plt
import os.path
import pandas as pd
# Every command word the CLI loop recognises; any other input is reported as invalid.
accepted_commands = {
    "exit",
    "filter",
    "help",
    "plot",
    "ratings",
    "refresh",
    "save",
    "stats",
    "view",
}
def cli_run(db: Database) -> None:
    """Run the interactive command loop for accessing, filtering and saving movie data.

    Loads the initial dataset with ``db.query_db()`` and then keeps prompting
    for a command until the user types ``exit`` or the input stream ends.
    Commands are matched case-insensitively and surrounding whitespace is
    ignored. The DataFrame returned by each command becomes the cached
    dataset handed to the next command.

    Args:
        db: Database handle used for the initial load and passed on to the
            individual commands.
    """
    cache_df = db.query_db()
    while True:
        try:
            raw_command = input("What would you like to do?: ")
        except EOFError:
            # stdin was closed (Ctrl-D, or piped input ran out): leave the
            # loop cleanly instead of crashing with a traceback.
            print()
            break

        # Strip so that stray whitespace ("help ") does not make an otherwise
        # valid command look invalid.
        command = raw_command.strip().lower()
        if command == "exit":
            break
        if command in accepted_commands:
            cache_df = process_command(command, db, cache_df)
        else:
            print("Invalid command, please try again.")
def process_command(command, db: Database, cache_df: pd.DataFrame) -> pd.DataFrame:
    """Dispatch a single CLI command word to its handler.

    ``filter`` and ``refresh`` produce a new DataFrame, which is returned in
    place of the cached one. Every other command, including a word this
    function does not recognise, leaves the cache alone and ``cache_df`` is
    returned unchanged.
    """
    # Commands that replace the cached dataset return straight away.
    if command == "filter":
        return filter_command(db)
    if command == "refresh":
        return refresh_command(db)

    # The remaining commands only read or display data.
    if command == "help":
        help_command()
    elif command == "plot":
        plot_command(db, cache_df)
    elif command == "ratings":
        ratings_command(db)
    elif command == "save":
        save_command(cache_df)
    elif command == "stats":
        stats_command(db, cache_df)
    elif command == "view":
        view_command(cache_df)

    return cache_df
def filter_command(db: Database) -> pd.DataFrame:
    """Prompt for filter criteria, query the database and show the result.

    The user chooses to filter by genre, by rating range, or by both. The
    collected criteria are passed to ``db.filter_query`` as ``opt_params``;
    the resulting DataFrame is printed with ``view_command`` and returned.
    If the filter type is not recognised, a message is printed and the
    result of ``refresh_command(db)`` is returned instead.

    Args:
        db: Database handle used to run the filter query.

    Returns:
        The filtered DataFrame, or the refreshed dataset on an invalid choice.
    """
    import math  # local import: only needed for the finite-number check below

    def prompt_rating(prompt):
        """Keep asking until the user enters a finite number, then return it."""
        while True:
            try:
                value = float(input(prompt))
            except ValueError:
                print("Please enter a valid number!")
                continue
            # float() also accepts "nan" and "inf". A NaN minimum makes every
            # later `max >= min` comparison false, so the range prompt would
            # never terminate; reject non-finite values up front.
            if math.isfinite(value):
                return value
            print("Please enter a valid number!")

    def get_genres():
        """Returns user input genres."""
        # Each whitespace-separated word becomes its own title-cased entry.
        return input("Enter genre: ").title().split()

    def get_ratings():
        """Returns user input min and max ratings."""
        rating_min = prompt_rating("Enter minimum rating: ")
        while True:
            rating_max = prompt_rating("Enter maximum rating: ")
            if rating_max >= rating_min:
                return rating_min, rating_max
            print(f"Please enter greater than or equal to {rating_min}")

    # Named filter_type rather than `type` so the builtin is not shadowed.
    filter_type = input(
        "What would you like to filter? ('genre', 'rating' or 'both'): "
    ).strip().lower()

    params = {}
    if filter_type in ('genre', 'g', 'both', 'b'):
        params['genres'] = get_genres()
    if filter_type in ('rating', 'r', 'both', 'b'):
        params['min_rating'], params['max_rating'] = get_ratings()

    if not params:
        # Nothing was collected, so the filter type was not recognised.
        print("That is not a valid command, please try again.")
        return refresh_command(db)

    df = db.filter_query(opt_params=params)
    view_command(df)
    return df
def help_command() -> None:
    """Print the accepted CLI commands, sorted alphabetically and comma-separated."""
    # str.join puts the separators between items, so no index bookkeeping
    # is needed to avoid a trailing comma.
    command_list = ", ".join(sorted(accepted_commands))
    print(f"Use any of the following commands: {command_list}")
def plot_command(db: Database, cache_df: pd.DataFrame) -> None:
    """Save a bar chart of movie counts per genre to ``genres.png``.

    Asks whether to chart the cached (filtered) data or the full dataset.
    For the cached data the counts are computed here by exploding the
    ``genres`` column; for the full dataset the frame returned by
    ``db.genre_stats_query()`` is plotted using its ``genre`` and ``count``
    columns. Any other answer prints a message and returns without plotting.

    Args:
        db: Database handle used when the full dataset is requested.
        cache_df: Currently cached DataFrame; must have a ``genres`` column.
    """
    choice = input("Use filtered data (f) or full dataset (full)? ").strip().lower()
    if choice == 'f':
        # assumes each 'genres' cell is list-like so explode() yields one
        # row per (movie, genre) pair — TODO confirm against the query layer
        genre_counts = cache_df.explode('genres').groupby('genres').size()
        if genre_counts.empty:
            # Nothing to draw; avoid handing an empty series to the plotter.
            print("There is no data to plot.")
            return
        ax = genre_counts.plot(kind="bar")
    elif choice == 'full':
        ax = db.genre_stats_query().plot(kind="bar", x='genre', y='count')
    else:
        print("That is not a valid command. Please try again!")
        return

    ax.set_title("Genre Counts")
    ax.set_xlabel("Genre")
    ax.set_ylabel("Count")

    # Work on the figure that owns these axes and close it afterwards;
    # otherwise every plot command leaves another open figure behind.
    fig = ax.get_figure()
    try:
        fig.tight_layout()
        fig.savefig('genres.png')
    finally:
        plt.close(fig)
    print("Plot image saved successfully: genres.png")
def ratings_command(db: Database) -> None:
    """Print every user rating, plus the average, for one movie.

    Prompts for a numeric movie id, runs ``db.ratings_query`` for it and
    prints the movie's title, id and genres followed by the ``userId`` /
    ``rating`` columns and the average. A non-numeric id or an empty result
    prints a message and returns.

    Args:
        db: Database handle used to look up the ratings.
    """
    print("Get user ratings for specified movie.")
    try:
        movie_id = int(input("Enter movie id: "))
    except ValueError:
        print("That is not a valid movie id!")
        return

    df = db.ratings_query(movie_id=movie_id)
    if df.empty:
        print(f"Couldn't find any results for {movie_id}.")
        return

    # The movie-level columns are read from the first row only.
    title = df['title'].iloc[0]
    found_id = df['id'].iloc[0]  # not named `id`, to avoid shadowing the builtin
    genres = df['genres'].iloc[0]
    average = df['avg'].iloc[0]

    print("User ratings (0.0 - 5.0) for:")
    print(f"\nTitle: {title} ({found_id}) | Genres: {genres}\n")
    print(df[['userId', 'rating']])
    print(f"\nAverage rating: {average}")
def refresh_command(db: Database) -> pd.DataFrame:
    """Re-run ``db.query_db()`` and hand back whatever it returns.

    Kept as its own function so a more involved refresh can be slotted in
    later without touching the callers.
    """
    refreshed_df = db.query_db()
    return refreshed_df
def save_command(df: pd.DataFrame) -> None:
    """Save the DataFrame to a CSV or JSON file chosen by the user.

    Prompts for a file name and a file type ('csv'/'c' or 'json'/'j'),
    appends the matching extension unless the name already ends with it, and
    asks for confirmation before overwriting an existing file. An empty
    name, an unknown file type, a declined overwrite or a failed write each
    print a message and leave the function without raising.

    Args:
        df: The DataFrame to write out.
    """
    def check_if_overwrite(name):
        """Return True when ``name`` is free or the user agrees to overwrite it."""
        if not os.path.isfile(name):
            return True
        answer = input("Would you like to overwrite this file? ('y'/'n'): ")
        if answer.strip().lower() == 'y':
            return True
        print("Please try again.")
        return False

    filename = input("Enter file name: ").strip()
    if not filename:
        # An empty name would otherwise create a hidden ".csv"/".json" file.
        print("File name cannot be empty. Please try again.")
        return

    file_type = input("Enter file type to save. ('csv' or 'json'): ").strip().lower()
    if file_type in ('csv', 'c'):
        extension, write = '.csv', df.to_csv
    elif file_type in ('json', 'j'):
        extension, write = '.json', df.to_json
    else:
        print('Invalid file type. Please try again.')
        return

    # Avoid "data.csv.csv" when the user already typed the extension.
    if not filename.lower().endswith(extension):
        filename += extension

    if not check_if_overwrite(filename):
        return

    try:
        write(filename)
    except OSError as err:
        # Bad path, missing directory, permissions...: report it and keep
        # the interactive session alive rather than crashing out of the CLI.
        print(f"Could not save {filename}: {err}")
    else:
        print(f"File saved successfully: {filename}")
def stats_command(db: Database, cache_df: pd.DataFrame) -> None:
    """Print per-genre counts and average ratings.

    Asks whether to use the cached (filtered) data or the full dataset.
    For 'full' the frame from ``db.genre_stats_query()`` is printed as-is;
    for 'f' the counts and mean ``vote_average`` per genre are computed from
    ``cache_df``. Any other answer prints an error message.
    """
    choice = input("Use filtered data (f) or full dataset (full)? ")

    if choice == "full":
        print(db.genre_stats_query())
        return

    if choice != 'f':
        print("That is not a valid command. Please try again.")
        return

    # One row per (movie, genre) pair, grouped once and reused for both stats.
    per_genre = cache_df.explode('genres').groupby('genres')
    summary = pd.DataFrame({
        'Count': per_genre.size(),
        'Avg Rating': per_genre['vote_average'].mean(),
    })
    print("\n", summary, "\n")
def view_command(df: pd.DataFrame) -> None:
    """Print the dataset with a blank line before and after it."""
    # print() joins the pieces with single spaces, so the leading and
    # trailing newlines frame the table.
    framed = ("\n", df, "\n")
    print(*framed)