Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def set_torch_cuda_arch_list():
set_torch_cuda_arch_list()

extension_root= os.path.join("torchmdnet", "extensions")
neighbor_sources=["neighbors_cpu.cpp"]
neighbor_sources=["neighbors_cpu.cpp", "neighbors_backward.cpp"]
if use_cuda:
neighbor_sources.append("neighbors_cuda.cu")
neighbor_sources = [os.path.join(extension_root, "neighbors", source) for source in neighbor_sources]
Expand Down
25 changes: 20 additions & 5 deletions tests/test_tensorforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from utils import load_example_args, create_example_batch
import random
import numpy as np
torch.autograd.set_detect_anomaly(True)

def set_all_seeds(seed):
torch.manual_seed(seed)
Expand All @@ -20,8 +21,9 @@ def test_tensorforce():
args["manual_grad"] = False
args["derivative"] = True
args["output_model"] = "Scalar"
args["seed"] = 1234
args["seed"] = 12345
args["num_layers"] = 0
args["precision"] = 32
set_all_seeds(args["seed"])
model_truth = create_model(args)

Expand All @@ -35,18 +37,31 @@ def test_tensorforce():
n_atoms = 200
z, pos, batch = create_example_batch(n_atoms=n_atoms)

if args["precision"] == 64:
pos = pos.to(torch.float64)
pos.requires_grad = True
energy, forces = model_truth(z, pos, batch)
(forces + energy.sum()).sum().backward()
force_diff = pos.grad.clone().detach()
pos.grad = None
del model_truth
energy_test, forces_test = model_test(z, pos, batch)
(forces_test+energy_test.sum()).sum().backward()
force_diff_test = pos.grad.clone().detach()
del model_test
enrgy_rel_error = torch.abs(energy - energy_test) / torch.abs(energy)
forces_rel_error = torch.abs(forces - forces_test) / torch.abs(forces)
force_diff_rel_error = torch.abs(force_diff - force_diff_test) / torch.abs(force_diff)

print("Max energy relative error: ", torch.max(enrgy_rel_error))
print("Max forces relative error: ", torch.max(forces_rel_error))
print("Max forces relative error in X direction: ", torch.max(forces_rel_error[:, 0]))
print("Max forces relative error in Y direction: ", torch.max(forces_rel_error[:, 1]))
print("Max forces relative error in Z direction: ", torch.max(forces_rel_error[:, 2]))
print("Max force diff relative error: ", torch.max(force_diff_rel_error))
print("Mean energy relative error: ", torch.mean(enrgy_rel_error))
print("Mean forces relative error: ", torch.mean(forces_rel_error))
print("Mean force diff relative error: ", torch.mean(force_diff_rel_error))


torch.testing.assert_allclose(energy, energy_test)
torch.testing.assert_allclose(forces, forces_test)
torch.testing.assert_close(energy, energy_test)
torch.testing.assert_close(forces, forces_test)
torch.testing.assert_close(force_diff, force_diff_test, rtol=1e-3, atol=1e-3)
174 changes: 156 additions & 18 deletions torchmdnet/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# The extensions will be available under torch.ops.torchmdnet_extensions, but you can add wrappers here to make them more convenient to use.
import os.path as osp
import torch
from torch import Tensor
import importlib.machinery
from typing import Tuple
from typing import Tuple, Optional, List


def _load_library(library):
Expand Down Expand Up @@ -39,18 +40,18 @@ def is_current_stream_capturing():
return _is_current_stream_capturing()


def get_neighbor_pairs_kernel(
def neighbor_forward(
strategy: str,
positions: torch.Tensor,
batch: torch.Tensor,
box_vectors: torch.Tensor,
positions: Tensor,
batch: Tensor,
box_vectors: Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Computes the neighbor pairs for a given set of atomic positions.

The list is generated as a list of pairs (i,j) without any enforced ordering.
Expand All @@ -60,11 +61,11 @@ def get_neighbor_pairs_kernel(
----------
strategy : str
Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
positions : torch.Tensor
positions : Tensor
A tensor with shape (N, 3) representing the atomic positions.
batch : torch.Tensor
batch : Tensor
A tensor with shape (N,). Specifies the batch for each atom.
box_vectors : torch.Tensor
box_vectors : Tensor
The vectors defining the periodic box with shape `(3, 3)`.
use_periodic : bool
Whether to apply periodic boundary conditions.
Expand All @@ -81,18 +82,18 @@ def get_neighbor_pairs_kernel(

Returns
-------
neighbors : torch.Tensor
neighbors : Tensor
List of neighbors for each atom. Shape (2, max_num_pairs).
distances : torch.Tensor
distances : Tensor
List of distances for each atom. Shape (max_num_pairs,).
distance_vecs : torch.Tensor
distance_vecs : Tensor
List of distance vectors for each atom. Shape (max_num_pairs, 3).
num_pairs : torch.Tensor
num_pairs : Tensor
The number of pairs found.

Notes
-----
This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
- This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
"""
return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
strategy,
Expand All @@ -108,8 +109,145 @@ def get_neighbor_pairs_kernel(
)


# For some unknown reason torch.compile is not able to compile this function
if int(torch.__version__.split(".")[0]) >= 2:
import torch._dynamo as dynamo
def neighbor_backward(
edge_index: Tensor,
edge_vec: Tensor,
edge_weight: Tensor,
num_atoms: int,
grad_edge_vec: Optional[Tensor] = None,
grad_edge_weight: Optional[Tensor] = None,
) -> Tensor:
"""Computes the neighbor pairs for a given set of atomic positions. This is the backwards pass of the :any:`get_neighbor_pairs_kernel` function.

Parameters
----------
edge_index : Tensor
A tensor with shape (2, max_num_pairs) representing the neighbor pairs.
edge_vec : Tensor
A tensor with shape (max_num_pairs, 3) representing the distance vectors.
edge_weight : Tensor
A tensor with shape (max_num_pairs,) representing the distances.
num_atoms : int
The number of atoms.
grad_edge_vec : Tensor, optional
The gradient of the distance vectors. If None, the gradient is assumed to be 1.
grad_edge_weight : Tensor, optional
The gradient of the distances. If None, the gradient is assumed to be 1.

Returns
-------
grad_positions : Tensor
The gradient of the positions. Shape (N, 3).
"""
if grad_edge_vec is None:
grad_edge_vec = torch.ones_like(edge_vec)
if grad_edge_weight is None:
grad_edge_weight = torch.ones_like(edge_weight)
return torch.ops.torchmdnet_extensions.get_neighbor_pairs_backward(
edge_index, edge_vec, edge_weight, grad_edge_vec, grad_edge_weight, num_atoms
)

dynamo.disallow_in_graph(get_neighbor_pairs_kernel)

# # For some unknown reason torch.compile is not able to compile this function
# if int(torch.__version__.split(".")[0]) >= 2:
# import torch._dynamo as dynamo

# dynamo.disallow_in_graph(get_neighbor_pairs_kernel)


# This class is a PyTorch autograd extension for computing neighbor pairs and their gradients.
class get_neighbor_pairs(torch.autograd.Function):
@staticmethod
def forward(
ctx,
strategy: str,
positions: Tensor,
batch: Tensor,
box_vectors: Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# Call the forward kernel and store the results
neighbors, deltas, distances, i_curr_pair = neighbor_forward(
strategy,
positions,
batch,
box_vectors,
use_periodic,
cutoff_lower,
cutoff_upper,
max_num_pairs,
loop,
include_transpose,
)
# Save tensors for backward computation
ctx.save_for_backward(neighbors, deltas, distances)
ctx.num_atoms = positions.size(0)
return (neighbors, deltas, distances, i_curr_pair)

@staticmethod
def backward(
ctx, *grad_outputs: Tensor
) -> Tuple[
Optional[Tensor],
Tensor,
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
]:
neighbors, deltas, distances = ctx.saved_tensors
num_atoms = ctx.num_atoms
grad_edge_vec = grad_outputs[1]
grad_edge_weight = grad_outputs[2]
grad_positions = neighbor_backward(
neighbors, deltas, distances, num_atoms, grad_edge_vec, grad_edge_weight
)
return (
None,
grad_positions,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)


def get_neighbor_pairs_kernel(
strategy: str,
positions: Tensor,
batch: Tensor,
box_vectors: Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return get_neighbor_pairs.apply(
strategy,
positions,
batch,
box_vectors,
use_periodic,
cutoff_lower,
cutoff_upper,
max_num_pairs,
loop,
include_transpose,
)
33 changes: 18 additions & 15 deletions torchmdnet/extensions/extensions.cpp
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
/* Raul P. Pelaez 2023. Torch extensions to the torchmdnet library.
* You can expose functions to python here which will be compatible with TorchScript.
* Add your exports to the TORCH_LIBRARY macro below, see __init__.py to see how to access them from python.
* The WITH_CUDA macro will be defined when compiling with CUDA support.
* Add your exports to the TORCH_LIBRARY macro below, see __init__.py to see how to access them from
* python. The WITH_CUDA macro will be defined when compiling with CUDA support.
*/


#include <torch/extension.h>
#if defined(WITH_CUDA)
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>
#endif
#include "neighbors/neighbors_backwards.h"

/* @brief Returns true if the current torch CUDA stream is capturing.
* This function is required because the one available in torch is not compatible with TorchScript.
* @return True if the current torch CUDA stream is capturing.
*/
bool is_current_stream_capturing() {
#if defined(WITH_CUDA)
auto current_stream = at::cuda::getCurrentCUDAStream().stream();
cudaStreamCaptureStatus capture_status;
cudaError_t err = cudaStreamGetCaptureInfo(current_stream, &capture_status, nullptr);
if (err != cudaSuccess) {
throw std::runtime_error(cudaGetErrorString(err));
}
return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
auto current_stream = at::cuda::getCurrentCUDAStream().stream();
cudaStreamCaptureStatus capture_status;
cudaError_t err = cudaStreamGetCaptureInfo(current_stream, &capture_status, nullptr);
if (err != cudaSuccess) {
throw std::runtime_error(cudaGetErrorString(err));
}
return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
#else
return false;
return false;
#endif
}


TORCH_LIBRARY(torchmdnet_extensions, m) {
m.def("is_current_stream_capturing", is_current_stream_capturing);
m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
"bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
"loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
"distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_backward", neighbors_backward);
}
26 changes: 26 additions & 0 deletions torchmdnet/extensions/neighbors/neighbors_backward.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "neighbors_backwards.h"

using torch::indexing::Slice;

Tensor neighbors_backward(const Tensor& edge_index, const Tensor& edge_vec,
const Tensor& edge_weight, const Tensor& grad_edge_vec,
const Tensor& grad_edge_weight, int64_t num_atoms) {
auto zero_mask = edge_weight == 0;
auto zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_edge_vec);
// We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the
// case of a double backwards. This is why we index_select like this.
auto grad_distances_ = edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1) *
grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1);
auto result = grad_edge_vec.masked_fill(zero_mask3, 0) + grad_distances_;
// Given that there is no masked_index_add function, in order to make the operation
// CUDA-graph compatible I need to transform masked indices into a dummy value (num_atoms)
// and then exclude that value from the output.
// TODO: replace this once masked_index_add or masked_scatter_add are available
auto grad_positions_ = torch::zeros({num_atoms + 1, 3}, edge_vec.options());
auto edge_index_ =
edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), num_atoms);
grad_positions_.index_add_(0, edge_index_[0], result);
grad_positions_.index_add_(0, edge_index_[1], -result);
auto grad_positions = grad_positions_.index({Slice(0, num_atoms), Slice()});
return grad_positions;
}
6 changes: 6 additions & 0 deletions torchmdnet/extensions/neighbors/neighbors_backwards.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <torch/extension.h>
using torch::Tensor;

Tensor neighbors_backward(const Tensor& edge_index, const Tensor& edge_vec,
const Tensor& edge_weight, const Tensor& grad_edge_vec,
const Tensor& grad_edge_weight, int64_t num_atoms);
3 changes: 1 addition & 2 deletions torchmdnet/extensions/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using std::tuple;
using torch::arange;
using torch::div;
using torch::frobenius_norm;
using torch::full;
using torch::hstack;
using torch::index_select;
Expand Down Expand Up @@ -68,7 +67,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
deltas -= outer(round(deltas.index({Slice(), 0}) / box_vectors.index({0, 0})),
box_vectors.index({0}));
}
distances = frobenius_norm(deltas, 1);
distances = torch::linalg::vector_norm(deltas, 2.0, 1, false, c10::nullopt);
mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
neighbors = neighbors.index({Slice(), mask});
deltas = deltas.index({mask, Slice()});
Expand Down
Loading