diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 36e335337..2fe00c638 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -9,20 +9,41 @@ Example usage ----- :: - python plot_ray_results.py .csv mean_reward max_reward + python plot_ray_results.py .csv .csv mean_reward max_reward """ import csv import argparse import matplotlib.pyplot as plt +import os from collections import defaultdict EXAMPLE_USAGE = 'plot_ray_results.py ' + \ - '~/ray_results/experiment-tag/experiment-name/seed-id/progress.csv ' + \ + '~/ray_results/experiment-tag/experiment-name/seed-id-1/progress.csv ' + \ + '~/ray_results/experiment-tag/experiment-name/seed-id-2/progress.csv ' + \ 'evaluation/return-average training/return-average' +def plot_multi_progresses(files_columns): + """Plot ray results from multiple csv files.""" + data = defaultdict(list) + plt.ion() + + filenames = [filename for filename in files_columns if '.csv' in filename] + columnnames = [column for column in files_columns if '.csv' not in column] + for filecsv in filenames: + data = plot_progress(filecsv, columnnames) + if not data: + return + dirname = os.path.dirname(filecsv) + for col_name, values in data.items(): + plt.plot(values, label=dirname) + plt.legend() + plt.show() + plt.savefig('testresult.png') + + def plot_progress(filepath, columns): """Plot ray results from a csv file. @@ -53,12 +74,7 @@ def plot_progress(filepath, columns): 'This column contains values that are not convertible to ' 'floats.'.format(__file__, col)) raise - - plt.ion() - for col_name, values in data.items(): - plt.plot(values, label=col_name) - plt.legend() - plt.show() + return data def create_parser(): @@ -74,9 +90,7 @@ def create_parser(): description='[Flow] Plots progress.csv file generated by ray.', epilog='Example usage:\n\t' + EXAMPLE_USAGE) - parser.add_argument('file', type=str, help='Path to the csv file.') - parser.add_argument( - 'columns', type=str, nargs='*', help='Names of the columns to plot.') + parser.add_argument('files_columns', type=str, nargs='*', help='Path to the csv files and columns to plot.') return parser @@ -84,4 +98,4 @@ def create_parser(): if __name__ == '__main__': parser = create_parser() args = parser.parse_args() - plot_progress(args.file, args.columns) + plot_multi_progresses(args.files_columns) diff --git a/tests/fast_tests/test_visualizers.py b/tests/fast_tests/test_visualizers.py index 7af413909..e12c6ce6e 100644 --- a/tests/fast_tests/test_visualizers.py +++ b/tests/fast_tests/test_visualizers.py @@ -330,22 +330,22 @@ def test_plot_ray_results(self): # test with one column args = parser.parse_args([file_path, 'episode_reward_mean']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with several columns args = parser.parse_args([file_path, 'episode_reward_mean', 'episode_reward_min', 'episode_reward_max']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with non-existing column name with self.assertRaises(KeyError): args = parser.parse_args([file_path, 'episode_reward']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with column containing non-float values with self.assertRaises(ValueError): args = parser.parse_args([file_path, 'info']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test that script outputs available column names if none is given column_names = [ @@ -382,7 +382,7 @@ def test_plot_ray_results(self): temp_stdout = StringIO() with contextlib.redirect_stdout(temp_stdout): args = parser.parse_args([file_path]) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) output = temp_stdout.getvalue() for column in column_names: