From 907ef6921ae9a3c6cfe78bfa7f4e826dd6685063 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 13:18:46 +0200 Subject: [PATCH 1/8] extend the plotting function to multiple files Major demand for plotting would be same value from different training, instead of different values from a same training. Therefore I extend this file's function to enable plotting multiple columns from multiple trainings. --- flow/visualize/plot_ray_results.py | 36 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 36e335337..0732870fb 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -9,7 +9,7 @@ Example usage ----- :: - python plot_ray_results.py .csv mean_reward max_reward + python plot_ray_results.py .csv .csv mean_reward max_reward """ import csv @@ -19,9 +19,26 @@ EXAMPLE_USAGE = 'plot_ray_results.py ' + \ - '~/ray_results/experiment-tag/experiment-name/seed-id/progress.csv ' + \ + '~/ray_results/experiment-tag/experiment-name/seed-id-1/progress.csv ' + \ + '~/ray_results/experiment-tag/experiment-name/seed-id-2/progress.csv ' + \ 'evaluation/return-average training/return-average' +def plot_multi_progresses(files_columns): + """Plot ray results from multiple csv files.""" + data = defaultdict(list) + plt.ion() + + filenames = [filename for filename in files_columns if '.csv' in filename] + columnnames = [column for column in files_columns if '.csv' not in column] + for filecsv in filenames: + data = plot_progress(filecsv, columnnames) + if not data: + return + for col_name, values in data.items(): + plt.plot(values, label=col_name+'/'+filecsv) + plt.legend() + plt.show() + plt.savefig('testresult.png') def plot_progress(filepath, columns): """Plot ray results from a csv file. @@ -53,13 +70,7 @@ def plot_progress(filepath, columns): 'This column contains values that are not convertible to ' 'floats.'.format(__file__, col)) raise - - plt.ion() - for col_name, values in data.items(): - plt.plot(values, label=col_name) - plt.legend() - plt.show() - + return data def create_parser(): """Parse visualization options user can specify in command line. @@ -74,9 +85,8 @@ def create_parser(): description='[Flow] Plots progress.csv file generated by ray.', epilog='Example usage:\n\t' + EXAMPLE_USAGE) - parser.add_argument('file', type=str, help='Path to the csv file.') - parser.add_argument( - 'columns', type=str, nargs='*', help='Names of the columns to plot.') + parser.add_argument('files_columns', type=str, nargs='*', + help='Path to the csv files, and names of the columns to plot.') return parser @@ -84,4 +94,4 @@ def create_parser(): if __name__ == '__main__': parser = create_parser() args = parser.parse_args() - plot_progress(args.file, args.columns) + plot_multi_progresses(args.files_columns) From 001e3ae87d7da08d782b1379ea5535b36a42a5f7 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 13:23:03 +0200 Subject: [PATCH 2/8] Update plot_ray_results.py --- flow/visualize/plot_ray_results.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 0732870fb..3739ff8a8 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -85,8 +85,7 @@ def create_parser(): description='[Flow] Plots progress.csv file generated by ray.', epilog='Example usage:\n\t' + EXAMPLE_USAGE) - parser.add_argument('files_columns', type=str, nargs='*', - help='Path to the csv files, and names of the columns to plot.') + parser.add_argument('files_columns', type=str, nargs='*', help='Path to the csv files, and names of the columns to plot.') return parser From eec57a8f9ebb7cc3cfa83e5e35996b00029d0271 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 13:26:27 +0200 Subject: [PATCH 3/8] Update plot_ray_results.py --- flow/visualize/plot_ray_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 3739ff8a8..e0a8f674d 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -85,7 +85,7 @@ def create_parser(): description='[Flow] Plots progress.csv file generated by ray.', epilog='Example usage:\n\t' + EXAMPLE_USAGE) - parser.add_argument('files_columns', type=str, nargs='*', help='Path to the csv files, and names of the columns to plot.') + parser.add_argument('files_columns', type=str, nargs='*', help='Path to the csv files and columns to plot.') return parser From 85f521562e522c4b7505346f71a6f757ffb28960 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 13:32:22 +0200 Subject: [PATCH 4/8] Update plot_ray_results.py --- flow/visualize/plot_ray_results.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index e0a8f674d..24b91fd46 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -23,15 +23,16 @@ '~/ray_results/experiment-tag/experiment-name/seed-id-2/progress.csv ' + \ 'evaluation/return-average training/return-average' + def plot_multi_progresses(files_columns): """Plot ray results from multiple csv files.""" data = defaultdict(list) plt.ion() - + filenames = [filename for filename in files_columns if '.csv' in filename] columnnames = [column for column in files_columns if '.csv' not in column] for filecsv in filenames: - data = plot_progress(filecsv, columnnames) + data = plot_progress(filecsv, columnnames) if not data: return for col_name, values in data.items(): @@ -40,6 +41,7 @@ def plot_multi_progresses(files_columns): plt.show() plt.savefig('testresult.png') + def plot_progress(filepath, columns): """Plot ray results from a csv file. @@ -72,6 +74,7 @@ def plot_progress(filepath, columns): raise return data + def create_parser(): """Parse visualization options user can specify in command line. From 78144c2febb3bcbc352876a948592df7220ae500 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 13:37:11 +0200 Subject: [PATCH 5/8] Update plot_ray_results.py --- flow/visualize/plot_ray_results.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 24b91fd46..557d65e89 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -28,7 +28,7 @@ def plot_multi_progresses(files_columns): """Plot ray results from multiple csv files.""" data = defaultdict(list) plt.ion() - + filenames = [filename for filename in files_columns if '.csv' in filename] columnnames = [column for column in files_columns if '.csv' not in column] for filecsv in filenames: @@ -41,7 +41,7 @@ def plot_multi_progresses(files_columns): plt.show() plt.savefig('testresult.png') - + def plot_progress(filepath, columns): """Plot ray results from a csv file. From 25d157dbd55d20159dfa03c79367c4a59a354ee9 Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 14:01:12 +0200 Subject: [PATCH 6/8] Update test_visualizers.py --- tests/fast_tests/test_visualizers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/fast_tests/test_visualizers.py b/tests/fast_tests/test_visualizers.py index 7af413909..17f3a0bf2 100644 --- a/tests/fast_tests/test_visualizers.py +++ b/tests/fast_tests/test_visualizers.py @@ -330,22 +330,22 @@ def test_plot_ray_results(self): # test with one column args = parser.parse_args([file_path, 'episode_reward_mean']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with several columns args = parser.parse_args([file_path, 'episode_reward_mean', 'episode_reward_min', 'episode_reward_max']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with non-existing column name with self.assertRaises(KeyError): args = parser.parse_args([file_path, 'episode_reward']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test with column containing non-float values with self.assertRaises(ValueError): args = parser.parse_args([file_path, 'info']) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) # test that script outputs available column names if none is given column_names = [ From 8f953fd19fe44ffc5b9648239ab3c00cbbe3498a Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 14:19:40 +0200 Subject: [PATCH 7/8] Update test_visualizers.py --- tests/fast_tests/test_visualizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fast_tests/test_visualizers.py b/tests/fast_tests/test_visualizers.py index 17f3a0bf2..e12c6ce6e 100644 --- a/tests/fast_tests/test_visualizers.py +++ b/tests/fast_tests/test_visualizers.py @@ -382,7 +382,7 @@ def test_plot_ray_results(self): temp_stdout = StringIO() with contextlib.redirect_stdout(temp_stdout): args = parser.parse_args([file_path]) - prr.plot_progress(args.file, args.columns) + prr.plot_multi_progresses(args.files_columns) output = temp_stdout.getvalue() for column in column_names: From fc875011723d02783fe9c7723ddcceb1a91feeaf Mon Sep 17 00:00:00 2001 From: zpymyyn Date: Sun, 22 Mar 2020 14:48:01 +0200 Subject: [PATCH 8/8] update legend with exp dirname --- flow/visualize/plot_ray_results.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flow/visualize/plot_ray_results.py b/flow/visualize/plot_ray_results.py index 557d65e89..2fe00c638 100644 --- a/flow/visualize/plot_ray_results.py +++ b/flow/visualize/plot_ray_results.py @@ -15,6 +15,7 @@ import csv import argparse import matplotlib.pyplot as plt +import os from collections import defaultdict @@ -35,8 +36,9 @@ def plot_multi_progresses(files_columns): data = plot_progress(filecsv, columnnames) if not data: return + dirname = os.path.dirname(filecsv) for col_name, values in data.items(): - plt.plot(values, label=col_name+'/'+filecsv) + plt.plot(values, label=dirname) plt.legend() plt.show() plt.savefig('testresult.png')