|
6 | 6 | import numpy as np |
7 | 7 | from numpy.lib.recfunctions import append_fields |
8 | 8 | from pandas import DataFrame, RangeIndex |
| 9 | +import pandas as pd |
9 | 10 | from root_numpy import root2array, list_trees |
10 | 11 | import fnmatch |
11 | 12 | from root_numpy import list_branches |
@@ -312,6 +313,15 @@ def convert_to_dataframe(array, start_index=None): |
312 | 313 | assert len(columns) == len(df.columns), (columns, df.columns) |
313 | 314 | df = df.reindex_axis(columns, axis=1, copy=False) |
314 | 315 |
|
| 316 | + # Convert categorical columns back to categories |
| 317 | + for c in df.columns: |
| 318 | + match = re.match(r'^__rpCaT\*([^\*]+)\*(True|False)\*', c) |
| 319 | + if match: |
| 320 | + real_name, ordered = match.groups() |
| 321 | + categories = c.split('*')[3:] |
| 322 | + df[c] = pd.Categorical.from_codes(df[c], categories, ordered={'True': True, 'False': False}[ordered]) |
| 323 | + df.rename(index=str, columns={c: real_name}, inplace=True) |
| 324 | + |
315 | 325 | return df |
316 | 326 |
|
317 | 327 |
|
@@ -353,12 +363,25 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg |
353 | 363 | from root_numpy import array2root |
354 | 364 | # We don't want to modify the user's DataFrame here, so we make a shallow copy |
355 | 365 | df_ = df.copy(deep=False) |
| 366 | + |
356 | 367 | if store_index: |
357 | 368 | name = df_.index.name |
358 | 369 | if name is None: |
359 | 370 | # Handle the case where the index has no name |
360 | 371 | name = '' |
361 | 372 | df_['__index__' + name] = df_.index |
| 373 | + |
| 374 | + # Convert categorical columns into something root_numpy can serialise |
| 375 | + for col in df_.select_dtypes(['category']).columns: |
| 376 | + name_components = ['__rpCaT', col, str(df_[col].cat.ordered)] |
| 377 | + name_components.extend(df_[col].cat.categories) |
| 378 | + if ['*' not in c for c in name_components]: |
| 379 | + sep = '*' |
| 380 | + else: |
| 381 | + raise ValueError('Unable to find suitable separator for columns') |
| 382 | + df_[col] = df_[col].cat.codes |
| 383 | + df_.rename(index=str, columns={col: sep.join(name_components)}, inplace=True) |
| 384 | + |
362 | 385 | arr = df_.to_records(index=False) |
363 | 386 | array2root(arr, path, key, mode=mode, *args, **kwargs) |
364 | 387 |
|
|
0 commit comments