diff --git a/qse/qbits.py b/qse/qbits.py index e464d8d..e223410 100644 --- a/qse/qbits.py +++ b/qse/qbits.py @@ -693,93 +693,69 @@ def __iter__(self): for i in range(len(self)): yield self[i] - def __getitem__(self, i): - """Return a subset of the qbits. - - i -- scalar integer, list of integers, or slice object - describing which qbits to return. - - If i is a scalar, return an Qbit object. If i is a list or a - slice, return an Qbits object with the same cell, pbc, and - other associated info as the original Qbits object. The - indices of the constraints will be shuffled so that they match - the indexing in the subset returned. + def __getitem__(self, indices): + """ + Return a subset of the qbits. + Parameters + ---------- + indices : int | list | slice + The indices to be returned. + + Returns + ------- + Qbit | Qbits. + If indices is a scalar a Qbit object is returned. If indices + is a list or a slice, a Qbits object with the same cell, pbc, and + other associated info as the original Qbits object is returned. """ - if isinstance(i, numbers.Integral): + if isinstance(indices, numbers.Integral): nqbits = len(self) - if i < -nqbits or i >= nqbits: + if indices < -nqbits or indices >= nqbits: raise IndexError("Index out of range.") + return Qbit(qbits=self, index=indices) - return Qbit(qbits=self, index=i) - elif not isinstance(i, slice): - i = np.array(i) - # if i is a mask - if i.dtype == bool: - if len(i) != len(self): + if not isinstance(indices, slice): + indices = np.array(indices) + # if indices is a mask. + if indices.dtype == bool: + if len(indices) != len(self): raise IndexError( - "Length of mask {} must equal " - "number of qbits {}".format(len(i), len(self)) + f"Length of mask {len(indices)} must equal " + f"number of qbits {len(self)}" ) - i = np.arange(len(self))[i] - - import copy - - conadd = [] - # Constraints need to be deepcopied, but only the relevant ones. - for con in copy.deepcopy(self.constraints): - try: - con.index_shuffle(self, i) - except (IndexError, NotImplementedError): - pass - else: - conadd.append(con) + indices = np.arange(len(self))[indices] qbits = self.__class__( cell=self.cell, pbc=self.pbc, info=self.info, - # should be communicated to the slice as well celldisp=self._celldisp, ) - # TODO: Do we need to shuffle indices in adsorbate_info too? qbits.arrays = {} for name, a in self.arrays.items(): - qbits.arrays[name] = a[i].copy() + qbits.arrays[name] = a[indices].copy() - qbits.constraints = conadd return qbits - def __delitem__(self, i): - from qse.constraints import FixQbits - - for c in self._constraints: - if not isinstance(c, FixQbits): - raise RuntimeError( - "Remove constraint using set_constraint() " "before deleting qbits." - ) + def __delitem__(self, indices): + """ + Delete a subset of the qbits. - if isinstance(i, list) and len(i) > 0: + Parameters + ---------- + indices : int | list + The indices to be deleted. + """ + if isinstance(indices, list) and len(indices) > 0: # Make sure a list of booleans will work correctly and not be # interpreted at 0 and 1 indices. - i = np.array(i) - - if len(self._constraints) > 0: - n = len(self) - i = np.arange(n)[i] - if isinstance(i, int): - i = [i] - constraints = [] - for c in self._constraints: - c = c.delete_qbits(i, n) - if c is not None: - constraints.append(c) - self.constraints = constraints + indices = np.array(indices) mask = np.ones(len(self), bool) - mask[i] = False + mask[indices] = False for name, a in self.arrays.items(): self.arrays[name] = a[mask] @@ -824,8 +800,6 @@ def __imul__(self, m): def draw(self, ax=None, radius=None): _draw(self, ax=ax, radius=radius) - # - def repeat(self, rep): """Create new repeated qbits object. diff --git a/tests/qbits_test.py b/tests/qbits_test.py index 0cba5d6..7a30bd0 100644 --- a/tests/qbits_test.py +++ b/tests/qbits_test.py @@ -127,3 +127,43 @@ def test_translate(nqbits, type_of_disp): else: qbits.translate(disp) assert np.allclose(qbits.get_positions(), positions + disp) + + +def test_get_item(): + """Test __getitem__""" + positions = np.random.rand(4, 3) + qbits = qse.Qbits(positions=positions) + + # test int + for indices in [0, 2]: + assert isinstance(qbits[indices], qse.Qbit) + assert np.allclose(qbits[indices].position, positions[indices]) + + # test list + for indices in [[0, 2], [1, 3, 2]]: + assert isinstance(qbits[indices], qse.Qbits) + assert np.allclose(qbits[indices].get_positions(), positions[indices]) + + # test slice + assert isinstance(qbits[1:3], qse.Qbits) + assert np.allclose(qbits[1:3].get_positions(), positions[1:3]) + + +@pytest.mark.parametrize("indices", [1, [0, 1, 3]]) +def test_del_item(indices): + """Test __delitem__""" + nqbits = 4 + positions = np.random.rand(nqbits, 3) + qbits = qse.Qbits(positions=positions) + + del qbits[indices] + + if isinstance(indices, int): + indices = [indices] + + assert isinstance(qbits, qse.Qbits) + assert len(qbits) == nqbits - len(indices) + assert np.allclose( + qbits.get_positions(), + positions[[i for i in range(nqbits) if i not in indices]], + )