Skip to content

Commit 9916708

Browse files
committed
FEAT: added support for dataframes with MultiIndex in columns in from_frame (closes #466)
1 parent 635b948 commit 9916708

File tree

3 files changed: +133 additions, -25 deletions

doc/source/changes/version_0_35.rst.inc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,9 @@ Miscellaneous improvements
9292

9393
>>> arr.plot.bar(stack='gender')
9494

95+
* :py:obj:`from_frame()` and :py:obj:`asarray()` now support Pandas DataFrames
96+
with more than one level (row) of columns (closes :issue:`466`).
97+
9598
* :py:obj:`Array.to_frame()` gained an ``ncolaxes`` argument to control how many
9699
axes should be used as columns (defaults to 1, as before).
97100

larray/inout/pandas.py

Lines changed: 54 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from larray.core.array import Array
77
from larray.core.axis import Axis, AxisCollection
88
from larray.core.constants import nan
9-
from larray.util.misc import unique_list
109

1110

1211
def decode(s, encoding='utf-8', errors='strict'):
@@ -46,34 +45,51 @@ def index_to_labels(idx, sort=True):
4645
"""
4746
if isinstance(idx, pd.MultiIndex):
4847
if sort:
49-
return list(idx.levels)
48+
return list(idx.levels) # list of pd.Index
5049
else:
51-
return [unique_list(idx.get_level_values(label)) for label in range(idx.nlevels)]
50+
# requires Pandas >= 0.23 (and it does NOT sort the values)
51+
# TODO: unsure to_list is necessary (larray tests pass without it
52+
# but I am not sure this code path is covered by tests)
53+
# and there might be a subtle difference. The type
54+
# of the returned object without to_list() is pd.Index
55+
return [idx.unique(level).to_list() for level in range(idx.nlevels)]
5256
else:
5357
assert isinstance(idx, pd.Index)
5458
labels = list(idx.values)
5559
return [sorted(labels) if sort else labels]
5660

5761

58-
def cartesian_product_df(df, sort_rows=False, sort_columns=False, fill_value=nan, **kwargs):
59-
idx = df.index
60-
labels = index_to_labels(idx, sort=sort_rows)
62+
def product_index(idx, sort=False):
63+
"""
64+
Converts a pandas (Multi)Index to an (Multi)Index with a cartesian
65+
product of the labels present in each level
66+
"""
67+
labels = index_to_labels(idx, sort=sort)
6168
if isinstance(idx, pd.MultiIndex):
62-
if sort_rows:
63-
new_index = pd.MultiIndex.from_product(labels)
64-
else:
65-
new_index = pd.MultiIndex.from_tuples(list(product(*labels)))
69+
return pd.MultiIndex.from_product(labels), labels
6670
else:
67-
if sort_rows:
68-
new_index = pd.Index(labels[0], name=idx.name)
71+
assert isinstance(idx, pd.Index)
72+
if sort:
73+
return pd.Index(labels[0], name=idx.name), labels
6974
else:
70-
new_index = idx
71-
columns = sorted(df.columns) if sort_columns else list(df.columns)
72-
# the prodlen test is meant to avoid the more expensive array_equal test
73-
prodlen = np.prod([len(axis_labels) for axis_labels in labels])
74-
if prodlen == len(df) and columns == list(df.columns) and np.array_equal(idx.values, new_index.values):
75-
return df, labels
76-
return df.reindex(index=new_index, columns=columns, fill_value=fill_value, **kwargs), labels
75+
return idx, labels
76+
77+
78+
def cartesian_product_df(df, sort_rows=False, sort_columns=False,
79+
fill_value=nan, **kwargs):
80+
idx = df.index
81+
columns = df.columns
82+
prod_index, index_labels = product_index(idx, sort=sort_rows)
83+
prod_columns, column_labels = product_index(columns, sort=sort_columns)
84+
combined_labels = index_labels + column_labels
85+
# the len() tests are meant to avoid the more expensive array_equal tests
86+
if (len(prod_index) == len(idx) and
87+
len(prod_columns) == len(columns) and
88+
np.array_equal(idx.values, prod_index.values) and
89+
np.array_equal(columns.values, prod_columns.values)):
90+
return df, combined_labels
91+
return df.reindex(index=prod_index, columns=prod_columns,
92+
fill_value=fill_value, **kwargs), combined_labels
7793

7894

7995
def from_series(s, sort_rows=False, fill_value=nan, meta=None, **kwargs) -> Array:
@@ -124,8 +140,13 @@ def from_series(s, sort_rows=False, fill_value=nan, meta=None, **kwargs) -> Arra
124140
a1 b1 6.0 7.0
125141
"""
126142
if isinstance(s.index, pd.MultiIndex):
127-
# TODO: use argument sort=False when it will be available
128-
# (see https://github.com/pandas-dev/pandas/issues/15105)
143+
# Using unstack sort argument (requires Pandas >= 2.1) would make this
144+
# code simpler, but it makes it even slower than it already is.
145+
# As of Pandas 2.3.3 on 12/2025, a series with a large MultiIndex is
146+
# extremely slow to unstack, whether sort is used or not:
147+
# >>> arr = ndtest((200, 200, 200))
148+
# >>> s = arr.to_series() # 31.4 ms
149+
# >>> s.unstack(level=-1, fill_value=np.nan) # 1.5s !!!
129150
df = s.unstack(level=-1, fill_value=fill_value)
130151
# pandas (un)stack and pivot(_table) methods return a Dataframe/Series with sorted index and columns
131152
if not sort_rows:
@@ -211,13 +232,15 @@ def from_frame(df, sort_rows=False, sort_columns=False, parse_header=False, unfo
211232

212233
# handle 2 or more dimensions with the last axis name given using \
213234
if unfold_last_axis_name:
235+
# Note that having several axes in columns (and using df.columns.names)
236+
# in this case does not make sense
214237
if isinstance(axes_names[-1], str) and '\\' in axes_names[-1]:
215238
last_axes = [name.strip() for name in axes_names[-1].split('\\')]
216239
axes_names = axes_names[:-1] + last_axes
217240
else:
218241
axes_names += [None]
219242
else:
220-
axes_names += [df.columns.name]
243+
axes_names += df.columns.names
221244

222245
if cartesian_prod:
223246
df, axes_labels = cartesian_product_df(df, sort_rows=sort_rows, sort_columns=sort_columns,
@@ -226,12 +249,18 @@ def from_frame(df, sort_rows=False, sort_columns=False, parse_header=False, unfo
226249
if sort_rows or sort_columns:
227250
raise ValueError('sort_rows and sort_columns cannot not be used when cartesian_prod is set to False. '
228251
'Please call the method sort_labels on the returned array to sort rows or columns')
229-
axes_labels = index_to_labels(df.index, sort=False)
252+
index_labels = index_to_labels(df.index, sort=False)
253+
column_labels = index_to_labels(df.columns, sort=False)
254+
axes_labels = index_labels + column_labels
230255

231256
# Pandas treats column labels as column names (strings) so we need to convert them to values
232-
last_axis_labels = [parse(cell) for cell in df.columns.values] if parse_header else list(df.columns.values)
233-
axes_labels.append(last_axis_labels)
257+
if parse_header:
258+
ncolaxes = df.columns.nlevels
259+
for i in range(len(axes_labels) - ncolaxes, len(axes_labels)):
260+
axes_labels[i] = [parse(cell) for cell in axes_labels[i]]
234261

262+
# TODO: use zip(..., strict=True) instead when we drop support for Python 3.9
263+
assert len(axes_labels) == len(axes_names)
235264
axes = AxisCollection([Axis(labels, name) for labels, name in zip(axes_labels, axes_names)])
236265
data = df.values.reshape(axes.shape)
237266
return Array(data, axes, meta=meta)

larray/tests/test_array.py

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4121,6 +4121,7 @@ def test_to_frame():
41214121
assert df.columns.to_list() == ['c0', 'c1']
41224122
assert df.index.names == ['a', r'b\c']
41234123

4124+
41244125
def test_from_frame():
41254126
# 1) data = scalar
41264127
# ================
@@ -4530,6 +4531,81 @@ def test_from_frame():
45304531
res = from_frame(df, fill_value=-1)
45314532
assert_larray_equal(res, expected)
45324533

4534+
# 6) with a multi-index in columns
4535+
# ================================
4536+
4537+
# a) normal
4538+
arr = ndtest((2, 2, 2, 2))
4539+
df = arr.to_frame(ncolaxes=2)
4540+
res = from_frame(df)
4541+
assert_larray_equal(res, arr)
4542+
4543+
# b) with duplicated axis names
4544+
arr = ndtest("a=a0,a1;a=b0,b1;a=c0,c1;a=d0,d1")
4545+
df = arr.to_frame(ncolaxes=2)
4546+
res = from_frame(df)
4547+
assert_larray_equal(res, arr)
4548+
4549+
# c) with duplicated axes names and labels
4550+
arr = ndtest("a=a0,a1;a=a0,a1;a=a0,a1;a=a0,a1")
4551+
df = arr.to_frame(ncolaxes=2)
4552+
res = from_frame(df)
4553+
assert_larray_equal(res, arr)
4554+
4555+
# d) with unsorted labels
4556+
arr = ndtest("a=a1,a0;b=b1,b0;c=c1,c0;d=d1,d0")
4557+
df = arr.to_frame(ncolaxes=2)
4558+
res = from_frame(df)
4559+
assert_larray_equal(res, arr)
4560+
4561+
# e) with sorting of unsorted column labels
4562+
arr = ndtest("a=a1,a0;b=b1,b0;c=c1,c0;d=d1,d0")
4563+
df = arr.to_frame(ncolaxes=2)
4564+
expected = from_string(r"""
4565+
a b c\d d0 d1
4566+
a1 b1 c0 3 2
4567+
a1 b1 c1 1 0
4568+
a1 b0 c0 7 6
4569+
a1 b0 c1 5 4
4570+
a0 b1 c0 11 10
4571+
a0 b1 c1 9 8
4572+
a0 b0 c0 15 14
4573+
a0 b0 c1 13 12""")
4574+
res = from_frame(df, sort_columns=True)
4575+
assert_larray_equal(res, expected)
4576+
4577+
# f) with sorting of unsorted row labels
4578+
arr = ndtest("a=a1,a0;b=b1,b0;c=c1,c0;d=d1,d0")
4579+
df = arr.to_frame(ncolaxes=2)
4580+
expected = from_string(r"""
4581+
a b c\d d1 d0
4582+
a0 b0 c1 12 13
4583+
a0 b0 c0 14 15
4584+
a0 b1 c1 8 9
4585+
a0 b1 c0 10 11
4586+
a1 b0 c1 4 5
4587+
a1 b0 c0 6 7
4588+
a1 b1 c1 0 1
4589+
a1 b1 c0 2 3""")
4590+
res = from_frame(df, sort_rows=True)
4591+
assert_larray_equal(res, expected)
4592+
4593+
# g) with sorting of all unsorted labels
4594+
arr = ndtest("a=a1,a0;b=b1,b0;c=c1,c0;d=d1,d0")
4595+
df = arr.to_frame(ncolaxes=2)
4596+
expected = from_string(r"""
4597+
a b c\d d0 d1
4598+
a0 b0 c0 15 14
4599+
a0 b0 c1 13 12
4600+
a0 b1 c0 11 10
4601+
a0 b1 c1 9 8
4602+
a1 b0 c0 7 6
4603+
a1 b0 c1 5 4
4604+
a1 b1 c0 3 2
4605+
a1 b1 c1 1 0""")
4606+
res = from_frame(df, sort_rows=True, sort_columns=True)
4607+
assert_larray_equal(res, expected)
4608+
45334609

45344610
def test_asarray():
45354611
series = pd.Series([0, 1, 2], ['a0', 'a1', 'a2'], name='a')

0 commit comments

Comments
 (0)