Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions dscribe/descriptors/acsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,17 @@ def validate_derivatives_method(self, method, attach):
raise ValueError(
"ACSF derivatives can only be calculated with attach=True."
)
return super().validate_derivatives_method(method, attach)

if method == "auto":
method = "analytical"
return method
elif method == "analytical":
return method
elif method == "numerical":
return super().validate_derivatives_method(method, attach)
else:
raise ValueError(
"%s derivative method not implemented in ACSF derivatives"%method
)
@property
def species(self):
return self._species
Expand Down Expand Up @@ -396,3 +405,77 @@ def g5_params(self):
@g5_params.setter
def g5_params(self, value):
self.acsf_wrapper.g5_params = self.validate_g5_params(value)


def derivatives_analytical(
self,
d,
c,
system,
centers,
indices,
attach,
return_descriptor=True,
):
"""Return the analytical derivatives for the given system.
Args:
system (:class:`ase.Atoms`): Atomic structure.
indices (list): Indices of atoms for which the derivatives will be computed.
return_descriptor (bool): Whether to also calculate the descriptor
in the same function call. This is true by default as it
typically is faster to calculate both in one go.
Returns:
If return_descriptor is True, returns a tuple, where the first item
is the derivative array and the second is the descriptor array.
Otherwise only returns the derivatives array. The derivatives array
is a 3D numpy array. The dimensions are: [n_atoms, 3, n_features].
The first dimension goes over the included atoms. The order is same
as the order of atoms in the given system. The second dimension
goes over the cartesian components, x, y and z. The last dimension
goes over the features in the default order.
"""

# Validate and normalize system
positions = self.validate_positions(system.get_positions())
atomic_numbers = self.validate_atomic_numbers(system.get_atomic_numbers())
pbc = self.validate_pbc(system.get_pbc())
cell = self.validate_cell(system.get_cell(), pbc)

# Create C-compatible list of atomic indices for which the ACSF is
# calculated
if centers is None:
centers = np.arange(len(system), dtype=np.int32)
else:
centers = np.asarray(centers, dtype=np.int32)
"""
if desc_centers is None:
desc_centers = np.arange(len(system), dtype=np.int32)
else:
desc_centers = np.asarray(desc_centers, dtype=np.int32)

if grad_centers is None:
grad_centers = np.arange(len(system), dtype=np.int32)
else:
grad_centers = np.asarray(grad_centers, dtype=np.int32)
"""
# --- CRUCIAL CHANGE HERE: Construct the CellList object in Python ---
# This assumes that CellList is exposed in your pybind11 wrapper as self.acsf_wrapper.CellList

# Call the C++ function with the correct arguments and order
self.acsf_wrapper.derivatives_analytical(
d, # 1st arg: py::array_t<double> derivatives
c, # 2nd arg: py::array_t<double> descriptor
atomic_numbers, # 3rd arg: py::array_t<int> desc_centers
cell, # 4th arg: py::array_t<double> desc_centers
pbc, # 5th arg: py::array_t<int> desc_centers
positions, # 6th arg: py::array_t<double> atomic_positions
centers, # 7th arg: py::array_t<int> desc_centers
centers, # 8th arg: py::array_t<int> grad_centers
return_descriptor # 8th arg: const bool return_descriptor
)

if return_descriptor:
return d, c
else:
return d

Loading