44
55import torch
66from torchmdnet .models .model import load_model
7+ import warnings
78
89# dict of preset transforms
910tranforms = {
@@ -42,6 +43,10 @@ class External:
4243 Whether to use CUDA graphs to speed up the calculation. Default: False
4344 cuda_graph_warmup_steps : int, optional
4445 Number of steps to run as warmup before recording the CUDA graph. Default: 12
46+ dtype : torch.dtype or str, optional
47+ Cast the input to this dtype if defined. If passed as a string it should be a valid torch dtype. Default: torch.float32
48+ kwargs : dict, optional
49+ Extra arguments to pass to the model when loading it.
4550 """
4651
4752 def __init__ (
@@ -52,10 +57,28 @@ def __init__(
5257 output_transform = None ,
5358 use_cuda_graph = False ,
5459 cuda_graph_warmup_steps = 12 ,
60+ dtype = torch .float32 ,
61+ ** kwargs ,
5562 ):
5663 if isinstance (netfile , str ):
57- self .model = load_model (netfile , device = device , derivative = True )
64+ extra_args = kwargs
65+ if use_cuda_graph :
66+ warnings .warn (
67+ "CUDA graphs are enabled, setting static_shapes=True and check_errors=False"
68+ )
69+ extra_args ["static_shapes" ] = True
70+ extra_args ["check_errors" ] = False
71+ self .model = load_model (
72+ netfile ,
73+ device = device ,
74+ derivative = True ,
75+ ** extra_args ,
76+ )
5877 elif isinstance (netfile , torch .nn .Module ):
78+ if kwargs :
79+ warnings .warn (
80+ "Warning: extra arguments are being ignored when passing a torch.nn.Module"
81+ )
5982 self .model = netfile
6083 else :
6184 raise ValueError (
@@ -87,6 +110,12 @@ def __init__(
87110 self .forces = None
88111 self .box = None
89112 self .pos = None
113+ if isinstance (dtype , str ):
114+ try :
115+ dtype = getattr (torch , dtype )
116+ except AttributeError :
117+ raise ValueError (f"Unknown torch dtype { dtype } " )
118+ self .dtype = dtype
90119
91120 def _init_cuda_graph (self ):
92121 stream = torch .cuda .Stream ()
@@ -101,7 +130,7 @@ def _init_cuda_graph(self):
101130 self .embeddings , self .pos , self .batch , self .box
102131 )
103132
104- def calculate (self , pos , box = None ):
133+ def calculate (self , pos , box = None ):
105134 """Calculate the energy and forces of the system.
106135
107136 Parameters
@@ -118,7 +147,9 @@ def calculate(self, pos, box = None):
118147 forces : torch.Tensor
119148 Forces on the atoms in the system.
120149 """
121- pos = pos .to (self .device ).type (torch .float32 ).reshape (- 1 , 3 )
150+ pos = pos .to (self .device ).to (self .dtype ).reshape (- 1 , 3 )
151+ if box is not None :
152+ box = box .to (self .device ).to (self .dtype )
122153 if self .use_cuda_graph :
123154 if self .pos is None :
124155 self .pos = (
@@ -128,10 +159,12 @@ def calculate(self, pos, box = None):
128159 .requires_grad_ (pos .requires_grad )
129160 )
130161 if self .box is None and box is not None :
131- self .box = box .clone ().to (self .device ).detach ()
162+ self .box = box .clone ().to (self .device ).to ( self . dtype ). detach ()
132163 if self .cuda_graph is None :
133164 self ._init_cuda_graph ()
134- assert self .cuda_graph is not None , "CUDA graph is not initialized. This should not had happened."
165+ assert (
166+ self .cuda_graph is not None
167+ ), "CUDA graph is not initialized. This should not had happened."
135168 with torch .no_grad ():
136169 self .pos .copy_ (pos )
137170 if box is not None :
0 commit comments