Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions flow/visualize/plot_ray_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,41 @@
Example usage
-----
::
python plot_ray_results.py </path/to/file>.csv mean_reward max_reward
python plot_ray_results.py </path/to/file1>.csv </path/to/file2>.csv mean_reward max_reward
"""

import csv
import argparse
import matplotlib.pyplot as plt
import os
from collections import defaultdict


EXAMPLE_USAGE = 'plot_ray_results.py ' + \
'~/ray_results/experiment-tag/experiment-name/seed-id/progress.csv ' + \
'~/ray_results/experiment-tag/experiment-name/seed-id-1/progress.csv ' + \
'~/ray_results/experiment-tag/experiment-name/seed-id-2/progress.csv ' + \
'evaluation/return-average training/return-average'


def plot_multi_progresses(files_columns):
"""Plot ray results from multiple csv files."""
data = defaultdict(list)
plt.ion()

filenames = [filename for filename in files_columns if '.csv' in filename]
columnnames = [column for column in files_columns if '.csv' not in column]
for filecsv in filenames:
data = plot_progress(filecsv, columnnames)
if not data:
return
dirname = os.path.dirname(filecsv)
for col_name, values in data.items():
plt.plot(values, label=dirname)
plt.legend()
plt.show()
plt.savefig('testresult.png')


def plot_progress(filepath, columns):
"""Plot ray results from a csv file.

Expand Down Expand Up @@ -53,12 +74,7 @@ def plot_progress(filepath, columns):
'This column contains values that are not convertible to '
'floats.'.format(__file__, col))
raise

plt.ion()
for col_name, values in data.items():
plt.plot(values, label=col_name)
plt.legend()
plt.show()
return data


def create_parser():
Expand All @@ -74,14 +90,12 @@ def create_parser():
description='[Flow] Plots progress.csv file generated by ray.',
epilog='Example usage:\n\t' + EXAMPLE_USAGE)

parser.add_argument('file', type=str, help='Path to the csv file.')
parser.add_argument(
'columns', type=str, nargs='*', help='Names of the columns to plot.')
parser.add_argument('files_columns', type=str, nargs='*', help='Path to the csv files and columns to plot.')

return parser


if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
plot_progress(args.file, args.columns)
plot_multi_progresses(args.files_columns)
10 changes: 5 additions & 5 deletions tests/fast_tests/test_visualizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,22 @@ def test_plot_ray_results(self):

# test with one column
args = parser.parse_args([file_path, 'episode_reward_mean'])
prr.plot_progress(args.file, args.columns)
prr.plot_multi_progresses(args.files_columns)

# test with several columns
args = parser.parse_args([file_path, 'episode_reward_mean',
'episode_reward_min', 'episode_reward_max'])
prr.plot_progress(args.file, args.columns)
prr.plot_multi_progresses(args.files_columns)

# test with non-existing column name
with self.assertRaises(KeyError):
args = parser.parse_args([file_path, 'episode_reward'])
prr.plot_progress(args.file, args.columns)
prr.plot_multi_progresses(args.files_columns)

# test with column containing non-float values
with self.assertRaises(ValueError):
args = parser.parse_args([file_path, 'info'])
prr.plot_progress(args.file, args.columns)
prr.plot_multi_progresses(args.files_columns)

# test that script outputs available column names if none is given
column_names = [
Expand Down Expand Up @@ -382,7 +382,7 @@ def test_plot_ray_results(self):
temp_stdout = StringIO()
with contextlib.redirect_stdout(temp_stdout):
args = parser.parse_args([file_path])
prr.plot_progress(args.file, args.columns)
prr.plot_multi_progresses(args.files_columns)
output = temp_stdout.getvalue()

for column in column_names:
Expand Down