diff --git a/src/scilpy/cli/scil_json_convert_entries_to_xlsx.py b/src/scilpy/cli/scil_json_convert_entries_to_xlsx.py index 98b5903f3..21e801970 100755 --- a/src/scilpy/cli/scil_json_convert_entries_to_xlsx.py +++ b/src/scilpy/cli/scil_json_convert_entries_to_xlsx.py @@ -17,7 +17,7 @@ from scilpy.version import version_string dps_dpp = ['data_per_streamline_keys', 'data_per_point_keys'] - +required_keys = {'mean', 'std'} def _get_all_bundle_names(stats): bnames = set() @@ -82,7 +82,7 @@ def _get_stats_parse_function(stats, stats_over_population): return _parse_lesion elif type(first_bundle_substat) is dict: sub_keys = list(first_bundle_substat.keys()) - if set(sub_keys) == set(['mean', 'std']): # when you have mean and std per stats + if required_keys.issubset(sub_keys): # when you have mean and std per stats if stats_over_population: return _parse_per_label_population_stats else: @@ -126,7 +126,7 @@ def _parse_scalar_stats(stats, subs, bundles): return dataframes, df_names -def _parse_scalar_meanstd(stats, subs, bundles): +def _parse_scalar_meanstd(stats, subs, bundles, optional_keys=None): metric_names = _get_metrics_names(stats) nb_subs = len(subs) @@ -135,6 +135,20 @@ means = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN) stddev = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN) + + found_keys = set() + for sub_dict in stats.values(): + for bundle_dict in sub_dict.values(): + for m_stat in bundle_dict.values(): + if isinstance(m_stat, dict): + found_keys.update(m_stat.keys()) + keys_present = set(optional_keys or ()).intersection(found_keys) + optional_arrays = {} + + for key in keys_present: + optional_arrays[key] = np.full( + (nb_subs, nb_bundles, nb_metrics), + np.nan) for sub_id, sub_name in enumerate(subs): for bundle_id, bundle_name in enumerate(bundles): @@ -147,6 +161,9 @@ if m_stat is not None: means[sub_id, bundle_id, metric_id] = m_stat['mean'] 
stddev[sub_id, bundle_id, metric_id] = m_stat['std'] + for key in keys_present: + optional_arrays[key][sub_id, bundle_id, metric_id] = \ + m_stat.get(key, np.nan) dataframes = [] df_names = [] @@ -157,8 +174,16 @@ df_names.append(metric_name + "_mean") dataframes.append(pd.DataFrame(data=stddev[:, :, metric_id], - index=subs, columns=bundles)) + index=subs, columns=bundles)) df_names.append(metric_name + "_std") + + for key in keys_present: + dataframes.append( + pd.DataFrame( + data=optional_arrays[key][:, :, metric_id], + index=subs, + columns=bundles)) + df_names.append(f"{metric_name}_{key}") return dataframes, df_names @@ -199,7 +224,7 @@ return dataframes, df_names -def _parse_stats(stats, subs, bundles): +def _parse_stats(stats, subs, bundles, optional_keys=None): nb_subs = len(subs) nb_bundles = len(bundles) @@ -408,7 +433,8 @@ def _parse_per_label_population_stats(stats, bundles, metrics): def _create_xlsx_from_json(json_path, xlsx_path, sort_subs=True, sort_bundles=True, ignored_bundles_fpath=None, - stats_over_population=False): + stats_over_population=False, + optional_keys=None): with open(json_path, 'r') as json_file: stats = json.load(json_file) @@ -428,7 +454,7 @@ cur_stats_func = _get_stats_parse_function(stats, stats_over_population) - dataframes, df_names = cur_stats_func(stats, subs, bundle_names) + dataframes, df_names = (cur_stats_func(stats, subs, bundle_names, optional_keys) if cur_stats_func in (_parse_scalar_meanstd, _parse_stats) else cur_stats_func(stats, subs, bundle_names)) if len(dataframes): _write_dataframes(dataframes, df_names, xlsx_path) @@ -447,6 +473,9 @@ def _build_arg_parser(): p.add_argument('--no_sort_subs', action='store_false', help='If set, subjects won\'t be sorted alphabetically.') + + p.add_argument('--extra_key', nargs='+', default=[], + help='Optional extra keys to export; each key must be associated with numeric values only.') p.add_argument('--no_sort_bundles', action='store_false', help='If 
set, bundles won\'t be sorted alphabetically.') @@ -468,6 +497,7 @@ def main(): parser = _build_arg_parser() args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + extra_keys = set(args.extra_key) assert_inputs_exist(parser, args.in_json) assert_outputs_exist(parser, args, args.out_xlsx) @@ -476,7 +506,8 @@ def main(): sort_subs=args.no_sort_subs, sort_bundles=args.no_sort_bundles, ignored_bundles_fpath=args.ignore_bundles, - stats_over_population=args.stats_over_population) + stats_over_population=args.stats_over_population, + optional_keys=extra_keys) if __name__ == "__main__":