|
2 | 2 | """ |
3 | 3 | A module that extends pandas to support the ROOT data format. |
4 | 4 | """ |
| 5 | +from collections import Counter |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 | from numpy.lib.recfunctions import append_fields |
@@ -95,11 +96,11 @@ def get_nonscalar_columns(array): |
95 | 96 | def get_matching_variables(branches, patterns, fail=True): |
96 | 97 | # Convert branches to a set to make x "in branches" O(1) on average |
97 | 98 | branches = set(branches) |
98 | | - patterns = set(patterns) |
99 | 99 | # Find any trivial matches |
100 | | - selected = list(branches.intersection(patterns)) |
| 100 | + selected = sorted(branches.intersection(patterns), |
| 101 | + key=lambda s: patterns.index(s)) |
101 | 102 | # Any matches that weren't trivial need to be looped over... |
102 | | - for pattern in patterns.difference(selected): |
| 103 | + for pattern in set(patterns).difference(selected): |
103 | 104 | found = False |
104 | 105 | # Avoid using fnmatch if the pattern if possible |
105 | 106 | if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern): |
@@ -317,7 +318,7 @@ def convert_to_dataframe(array, start_index=None): |
317 | 318 | # Filter to remove __index__ columns |
318 | 319 | columns = [c for c in array.dtype.names if c in df.columns] |
319 | 320 | assert len(columns) == len(df.columns), (columns, df.columns) |
320 | | - df = df.reindex_axis(columns, axis=1, copy=False) |
| 321 | + df = df.reindex(columns, axis=1, copy=False) |
321 | 322 |
|
322 | 323 | # Convert categorical columns back to categories |
323 | 324 | for c in df.columns: |
@@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg |
366 | 367 | else: |
367 | 368 | raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode)) |
368 | 369 |
|
| 370 | + column_name_counts = Counter(df.columns) |
| 371 | + if max(column_name_counts.values()) > 1: |
| 372 | + raise ValueError('DataFrame contains duplicated column names: ' + |
| 373 | + ' '.join({k for k, v in column_name_counts.items() if v > 1})) |
| 374 | + |
369 | 375 | from root_numpy import array2tree |
370 | 376 | # We don't want to modify the user's DataFrame here, so we make a shallow copy |
371 | 377 | df_ = df.copy(deep=False) |
|
0 commit comments