Skip to content

Commit 0269c04

Browse files
committed
Merge branch 'indexing' of https://github.com/edwinsolisf/arrayfire-py into afwheel310
2 parents ba78805 + 0b3f859 commit 0269c04

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

arrayfire/array_object.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,14 @@ def __getitem__(self, key: IndexKey, /) -> Array:
780780
indexing = tuple(key_list)
781781

782782
out._arr = wrapper.index_gen(self._arr, ndims, wrapper.get_indices(indexing)) # type: ignore[arg-type]
783+
784+
if isinstance(key, Array) and key.is_bool:
785+
wrapper.release_array(indexing)
786+
elif isinstance(key, tuple):
787+
for i in range(len(key)):
788+
if isinstance(key[i], Array) and key[i].is_bool:
789+
wrapper.release_array(indexing[i])
790+
783791
return out
784792

785793
def __index__(self) -> int:
@@ -807,7 +815,6 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No
807815
808816
"""
809817
ndims = self.ndim
810-
811818
is_array_with_bool = isinstance(key, Array) and type(key) is afbool
812819

813820
if is_array_with_bool:
@@ -842,20 +849,27 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No
842849
for elem in key:
843850
if isinstance(elem, Array):
844851
if elem.is_bool:
845-
key_list.append(wrapper.where(elem.arr))
852+
locs = wrapper.where(elem.arr)
853+
key_list.append(locs)
846854
else:
847855
key_list.append(elem.arr)
848856
else:
849857
key_list.append(elem)
850858
indexing = tuple(key_list)
851859

852-
indices = wrapper.get_indices(indexing)
860+
out = wrapper.assign_gen(self._arr, other_arr, ndims, wrapper.get_indices(indexing))
853861

854-
out = wrapper.assign_gen(self._arr, other_arr, ndims, indices)
862+
if isinstance(key, Array) and key.is_bool:
863+
wrapper.release_array(indexing)
864+
elif isinstance(key, tuple):
865+
for i in range(len(key)):
866+
if isinstance(key[i], Array) and key[i].is_bool:
867+
wrapper.release_array(indexing[i])
855868

856869
wrapper.release_array(self._arr)
857870
if del_other:
858871
wrapper.release_array(other_arr)
872+
859873
self._arr = out
860874

861875
def __str__(self) -> str:

0 commit comments

Comments
 (0)