Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 2539294

Browse files
authored
Merge pull request #69 from chrisburr/support-categories
Add support for writing/reading categorical columns from pandas
2 parents 4ba1279 + 7546bfc commit 2539294

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

root_pandas/readwrite.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from numpy.lib.recfunctions import append_fields
88
from pandas import DataFrame, RangeIndex
9+
import pandas as pd
910
from root_numpy import root2array, list_trees
1011
import fnmatch
1112
from root_numpy import list_branches
@@ -312,6 +313,15 @@ def convert_to_dataframe(array, start_index=None):
312313
assert len(columns) == len(df.columns), (columns, df.columns)
313314
df = df.reindex_axis(columns, axis=1, copy=False)
314315

316+
# Convert categorical columns back to categories
317+
for c in df.columns:
318+
match = re.match(r'^__rpCaT\*([^\*]+)\*(True|False)\*', c)
319+
if match:
320+
real_name, ordered = match.groups()
321+
categories = c.split('*')[3:]
322+
df[c] = pd.Categorical.from_codes(df[c], categories, ordered={'True': True, 'False': False}[ordered])
323+
df.rename(index=str, columns={c: real_name}, inplace=True)
324+
315325
return df
316326

317327

@@ -353,12 +363,25 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
353363
from root_numpy import array2root
354364
# We don't want to modify the user's DataFrame here, so we make a shallow copy
355365
df_ = df.copy(deep=False)
366+
356367
if store_index:
357368
name = df_.index.name
358369
if name is None:
359370
# Handle the case where the index has no name
360371
name = ''
361372
df_['__index__' + name] = df_.index
373+
374+
# Convert categorical columns into something root_numpy can serialise
375+
for col in df_.select_dtypes(['category']).columns:
376+
name_components = ['__rpCaT', col, str(df_[col].cat.ordered)]
377+
name_components.extend(df_[col].cat.categories)
378+
if ['*' not in c for c in name_components]:
379+
sep = '*'
380+
else:
381+
raise ValueError('Unable to find suitable separator for columns')
382+
df_[col] = df_[col].cat.codes
383+
df_.rename(index=str, columns={col: sep.join(name_components)}, inplace=True)
384+
362385
arr = df_.to_records(index=False)
363386
array2root(arr, path, key, mode=mode, *args, **kwargs)
364387

0 commit comments

Comments
 (0)