Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions src/scilpy/cli/scil_json_convert_entries_to_xlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scilpy.version import version_string

dps_dpp = ['data_per_streamline_keys', 'data_per_point_keys']

required_keys = {'mean', 'std'}

def _get_all_bundle_names(stats):
bnames = set()
Expand Down Expand Up @@ -82,7 +82,7 @@ def _get_stats_parse_function(stats, stats_over_population):
return _parse_lesion
elif type(first_bundle_substat) is dict:
sub_keys = list(first_bundle_substat.keys())
if set(sub_keys) == set(['mean', 'std']): # when you have mean and std per stats
if required_keys.issubset(sub_keys): # when you have mean and std per stats
if stats_over_population:
return _parse_per_label_population_stats
else:
Expand Down Expand Up @@ -126,7 +126,7 @@ def _parse_scalar_stats(stats, subs, bundles):
return dataframes, df_names


def _parse_scalar_meanstd(stats, subs, bundles):
def _parse_scalar_meanstd(stats, subs, bundles, optional_keys):
metric_names = _get_metrics_names(stats)

nb_subs = len(subs)
Expand All @@ -135,6 +135,20 @@ def _parse_scalar_meanstd(stats, subs, bundles):

means = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN)
stddev = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN)

found_keys = set()
for sub_dict in stats.values():
for bundle_dict in sub_dict.values():
for m_stat in bundle_dict.values():
if isinstance(m_stat, dict):
found_keys.update(m_stat.keys())
keys_present = set(optional_keys).intersection(found_keys)
optional_arrays = {}

for key in keys_present:
optional_arrays[key] = np.full(
(nb_subs, nb_bundles, nb_metrics),
np.nan)

for sub_id, sub_name in enumerate(subs):
for bundle_id, bundle_name in enumerate(bundles):
Expand All @@ -147,6 +161,9 @@ def _parse_scalar_meanstd(stats, subs, bundles):
if m_stat is not None:
means[sub_id, bundle_id, metric_id] = m_stat['mean']
stddev[sub_id, bundle_id, metric_id] = m_stat['std']
for key in keys_present:
optional_arrays[key][sub_id, bundle_id, metric_id] = \
m_stat.get(key, np.nan)

dataframes = []
df_names = []
Expand All @@ -157,8 +174,16 @@ def _parse_scalar_meanstd(stats, subs, bundles):
df_names.append(metric_name + "_mean")

dataframes.append(pd.DataFrame(data=stddev[:, :, metric_id],
index=subs, columns=bundles))
index=subs, columns=bundles))
df_names.append(metric_name + "_std")

for key in keys_present:
dataframes.append(
pd.DataFrame(
data=optional_arrays[key][:, :, metric_id],
index=subs,
columns=bundles))
df_names.append(f"{metric_name}_{key}")

return dataframes, df_names

Expand Down Expand Up @@ -199,7 +224,7 @@ def _parse_scalar_lesions(stats, subs, bundles):
return dataframes, df_names


def _parse_stats(stats, subs, bundles):
def _parse_stats(stats, subs, bundles, optional_keys=None):
nb_subs = len(subs)
nb_bundles = len(bundles)

Expand Down Expand Up @@ -408,7 +433,8 @@ def _parse_per_label_population_stats(stats, bundles, metrics):
def _create_xlsx_from_json(json_path, xlsx_path,
sort_subs=True, sort_bundles=True,
ignored_bundles_fpath=None,
stats_over_population=False):
stats_over_population=False,
optional_keys=None):
with open(json_path, 'r') as json_file:
stats = json.load(json_file)

Expand All @@ -428,7 +454,7 @@ def _create_xlsx_from_json(json_path, xlsx_path,

cur_stats_func = _get_stats_parse_function(stats, stats_over_population)

dataframes, df_names = cur_stats_func(stats, subs, bundle_names)
dataframes, df_names = cur_stats_func(stats, subs, bundle_names, optional_keys)

if len(dataframes):
_write_dataframes(dataframes, df_names, xlsx_path)
Expand All @@ -447,6 +473,9 @@ def _build_arg_parser():

p.add_argument('--no_sort_subs', action='store_false',
help='If set, subjects won\'t be sorted alphabetically.')

p.add_argument('--extra_key', nargs='+', default=[],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'Optional keys to export (must be associated to numeric values only).'

help='Optional keys to export (must be associated to numeric values only)')

p.add_argument('--no_sort_bundles', action='store_false',
help='If set, bundles won\'t be sorted alphabetically.')
Expand All @@ -468,6 +497,7 @@ def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))
extra_keys = set(args.extra_key)

assert_inputs_exist(parser, args.in_json)
assert_outputs_exist(parser, args, args.out_xlsx)
Expand All @@ -476,7 +506,8 @@ def main():
sort_subs=args.no_sort_subs,
sort_bundles=args.no_sort_bundles,
ignored_bundles_fpath=args.ignore_bundles,
stats_over_population=args.stats_over_population)
stats_over_population=args.stats_over_population,
optional_keys=extra_keys)


if __name__ == "__main__":
Expand Down