Skip to content

Commit 57e5a3d

Browse files
committed
CLN: implemented check_compatible argument in broadcast_with
1 parent a46e8c1 commit 57e5a3d

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

larray/core/array.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,9 +1977,7 @@ def __setitem__(self, key, value, collapse_slices=True, translate_key=True, poin
19771977
raw_broadcasted_key, target_axes, _ = \
19781978
self.axes._key_to_raw_and_axes(key, collapse_slices, translate_key, points, wildcard=True)
19791979
if isinstance(value, Array):
1980-
# TODO: the check_compatible should be included in broadcast_with
1981-
value = value.broadcast_with(target_axes)
1982-
value.axes.check_compatible(target_axes)
1980+
value = value.broadcast_with(target_axes, check_compatible=True)
19831981

19841982
# replace incomprehensible error message "could not broadcast input array from shape XX into shape YY"
19851983
# for users by "incompatible axes"
@@ -2084,7 +2082,7 @@ def reshape(self, target_axes):
20842082
# 4, 3, 2 -> 2, 2, 3, 2 is potentially ok (splitting dim)
20852083
if not isinstance(target_axes, AxisCollection):
20862084
target_axes = AxisCollection(target_axes)
2087-
data = np.asarray(self).reshape(target_axes.shape)
2085+
data = self.data.reshape(target_axes.shape)
20882086
return Array(data, target_axes)
20892087

20902088
# TODO: this should be a private method
@@ -2114,7 +2112,7 @@ def reshape_like(self, target):
21142112
"""
21152113
return self.reshape(target.axes)
21162114

2117-
def broadcast_with(self, target):
2115+
def broadcast_with(self, target, check_compatible=False):
21182116
r"""
21192117
Returns an array that is (NumPy) broadcastable with target.
21202118
@@ -2131,6 +2129,9 @@ def broadcast_with(self, target):
21312129
----------
21322130
target : Array or collection of Axis
21332131
2132+
check_compatible : bool, optional
2133+
Whether or not to check that common axes are compatible. Defaults to False.
2134+
21342135
Returns
21352136
-------
21362137
Array
@@ -2144,6 +2145,7 @@ def broadcast_with(self, target):
21442145
if self.axes == target_axes:
21452146
return self
21462147
# determine real target order (= left_only then target_axes)
2148+
# (we will add length one axes to the left like numpy just below)
21472149
target_axes = (self.axes - target_axes) | target_axes
21482150

21492151
# XXX: this breaks la['1,5,9'] = la['2,7,3']
@@ -2154,7 +2156,10 @@ def broadcast_with(self, target):
21542156
array = self.transpose(target_axes & self.axes)
21552157

21562158
# 2) add length one axes
2157-
return array.reshape(array.axes.get_all(target_axes))
2159+
res_axes = array.axes.get_all(target_axes)
2160+
if check_compatible:
2161+
res_axes.check_compatible(target_axes)
2162+
return array.reshape(res_axes)
21582163

21592164
# XXX: I wonder if effectively dropping the labels is necessary or not
21602165
# we could perfectly only mark the axis as being a wildcard axis and keep

0 commit comments

Comments
 (0)