I was trying to JIT-compile with version 0.9.1 from scratch and kept running into numerous errors. Here is the full Python script:
class Equivariance:
    """Helpers used by the equivariant layer.

    Groups the edge-geometry / cutoff-envelope computation and the radial
    MLP that turns scaled edge lengths into per-edge weights.
    """

    @staticmethod
    def compute_radial_geometry(pos: torch.Tensor, edge_index: torch.Tensor, cutoff: float,
                                edge_attributes: torch.Tensor, p: float = 6.0):
        """Compute edge vectors, a polynomial cutoff envelope and scaled lengths.

        Args:
            pos: Node positions; rows are gathered with ``edge_index``.
            edge_index: Integer tensor whose row 0 and row 1 hold the two
                endpoint indices of every edge.
            cutoff: Cutoff radius; ``edge_attributes`` are divided by it.
            edge_attributes: Edge lengths (presumably ``|pos[row] - pos[col]|``;
                they are not recomputed here -- TODO confirm at the call site).
            p: Envelope exponent. The default 6.0 is the value that was
                previously hard-coded.

        Returns:
            ``(edge_vec, env, d_scaled)`` where ``edge_vec = pos[row] - pos[col]``,
            ``d_scaled = edge_attributes / cutoff`` and ``env`` (same shape as
            ``d_scaled``) is
            ``1 - (p+1)(p+2)/2 * d^p + p(p+2) * d^(p+1) - p(p+1)/2 * d^(p+2)``
            for ``d < 1`` and exactly 0 for ``d >= 1``.
        """
        row, col = edge_index[0], edge_index[1]
        edge_vec = pos[row] - pos[col]
        d_scaled = edge_attributes / cutoff
        env = (
            1.0
            - (p + 1) * (p + 2) / 2 * torch.pow(d_scaled, p)
            + p * (p + 2) * torch.pow(d_scaled, p + 1)
            - p * (p + 1) / 2 * torch.pow(d_scaled, p + 2)
        )
        # Beyond the cutoff the polynomial takes nonzero values again,
        # so clamp the envelope to exactly zero there.
        env = torch.where(d_scaled < 1.0, env, torch.zeros_like(env))
        return edge_vec, env, d_scaled

    class OptimizedRadial(nn.Module):
        """Sine radial basis scaled by an envelope, followed by a small MLP.

        Args:
            num_kernels: Number of sine basis functions; the frequencies are
                ``pi * k`` for ``k = 1 .. num_kernels`` (stored as a buffer).
            hidden: Hidden width of the MLP.
            out_dim: Output width of the MLP (e.g. the number of per-edge
                tensor-product weights).
        """

        def __init__(self, num_kernels, hidden, out_dim):
            super().__init__()
            self.register_buffer("frequencies", torch.pi * torch.arange(1, num_kernels + 1))
            self.mlp = nn.Sequential(
                nn.Linear(num_kernels, hidden),
                nn.SiLU(),
                nn.Linear(hidden, out_dim),
            )

        def forward(self, d_scaled, env):
            """Return ``mlp(env * sin(d_scaled * frequencies))``.

            ``d_scaled`` and ``env`` may be 1-D ``(E,)`` or have a trailing
            singleton axis ``(E, 1)``. A 1-D input gets that axis added, so the
            basis is always ``(E, num_kernels)`` and the result ``(E, out_dim)``.
            Without this, 1-D input either failed to broadcast or, when
            ``E == num_kernels``, silently paired edge i with kernel i and
            returned a single row.
            """
            if d_scaled.dim() == 1:
                d_scaled = d_scaled.unsqueeze(-1)
            if env.dim() == 1:
                env = env.unsqueeze(-1)
            basis = env * torch.sin(d_scaled * self.frequencies.view(1, -1))
            return self.mlp(basis)
class EquivariantLayer(nn.Module):
    """Equivariant message-passing layer built from cuEquivariance operations.

    Data flow in ``forward`` (internally everything uses ``cue.ir_mul``):

    1. node features ``x`` are transposed from ``mul_ir`` to ``ir_mul``;
    2. optionally, a per-class linear maps ``irreps_in`` -> ``irreps_pre``;
    3. a channel-wise tensor product combines source-node features
       (``edge_index[0]``) with the edge spherical harmonics, weighted per
       edge by a radial MLP, and writes (presumably sums) the result onto the
       target nodes ``edge_index[1]``;
    4. a per-class linear maps the tensor-product output to ``irreps_hidden``;
    5. a symmetric contraction of degree ``nu``, conditioned on ``atom_class``;
    6. a per-class linear to ``irreps_out``, transposed back to ``mul_ir``.

    Args:
        irreps_in: O(3) irreps string of the input node features.
        irreps_pre: Irreps string produced by the optional pre-linear;
            required when ``use_pre_linear=True``.
        irreps_hidden: Irreps string fed into the symmetric contraction.
        irreps_out: Irreps string of the output node features.
        lmax: Highest spherical-harmonic degree used for the edges.
        nu: Contraction degree of the symmetric contraction.
        device: Device for all submodules. Any CUDA spec (``"cuda"``,
            ``"cuda:1"``, ``torch.device("cuda")``) selects the
            ``indexed_linear`` method for the linears; anything else
            selects ``naive``.
        radius: Cutoff radius used to scale edge lengths.
        use_pre_linear: Whether to apply a per-class linear before the
            tensor product.
    """

    def __init__(self, irreps_in="8x0e", irreps_pre=None, irreps_hidden="128x0e + 128x1o",
                 irreps_out="2x0e", lmax=1, nu=2, device='cuda', radius=1.0,
                 use_pre_linear=False):
        super().__init__()
        self.radius = radius
        self.device = device
        # Number of weight classes (atom types) shared by every per-class
        # linear and the symmetric contraction; forward() checks atom_class
        # against it.
        self.num_elements = 5

        # Pick the Linear method once. torch.device() normalises "cuda",
        # "cuda:1" and torch.device("cuda") alike; a plain ``device == "cuda"``
        # test sent indexed or torch.device CUDA specs to "naive".
        linear_method = "indexed_linear" if torch.device(device).type == "cuda" else "naive"

        self.irreps_in = cue.Irreps(cue.O3, irreps_in)
        self.irreps_hidden = cue.Irreps(cue.O3, irreps_hidden)
        # ``ireps_pre`` (historical spelling, kept for compatibility) holds the
        # irreps that enter the tensor product.
        if use_pre_linear:
            assert irreps_pre is not None, "irreps_pre must be provided if use_pre_linear=True"
            self.ireps_pre = cue.Irreps(cue.O3, irreps_pre)
        else:
            self.ireps_pre = self.irreps_in
        self.irreps_out = cue.Irreps(cue.O3, irreps_out)
        # One copy of each spherical-harmonic degree 0..lmax with parity (-1)^l.
        self.irreps_sh = cue.Irreps(
            cue.O3,
            " + ".join(f"1x{l}{'e' if l % 2 == 0 else 'o'}" for l in range(lmax + 1)),
        )

        # NOTE: submodules are created in the same order as before, so seeded
        # parameter initialisation and the state_dict layout are unchanged.
        self.spherical = cuet.SphericalHarmonics(list(range(lmax + 1)), normalize=True, device=device)
        if use_pre_linear:
            self.pre_linear = cuet.Linear(
                self.irreps_in,
                self.ireps_pre,
                layout=cue.ir_mul,
                internal_weights=True,
                weight_classes=self.num_elements,
                device=device,
                method=linear_method,
            )
        else:
            self.pre_linear = None
        # Third positional argument: output irreps filter for the tensor product.
        self.tp = cuet.ChannelWiseTensorProduct(
            self.ireps_pre,
            self.irreps_sh,
            self.irreps_sh,
            layout=cue.ir_mul,
            device=device,
            shared_weights=False,
            internal_weights=False,
        )
        self.tp_to_hidden = cuet.Linear(
            self.tp.irreps_out,
            self.irreps_hidden,
            layout=cue.ir_mul,
            internal_weights=True,
            weight_classes=self.num_elements,
            device=device,
            method=linear_method,
        )
        # Radial MLP emits exactly the per-edge weights the tensor product needs.
        self.radial = Equivariance.OptimizedRadial(
            num_kernels=20,
            hidden=32,
            out_dim=self.tp.weight_numel,
        ).to(device)
        self.sym_cont = cuet.SymmetricContraction(
            self.irreps_hidden,
            self.irreps_hidden,
            contraction_degree=nu,
            layout_in=cue.ir_mul,
            layout_out=cue.ir_mul,
            device=device,
            original_mace=True,
            num_elements=self.num_elements,
            dtype=torch.float32,
        )
        self.linear = cuet.Linear(
            self.sym_cont.irreps_out,
            self.irreps_out,
            layout=cue.ir_mul,
            internal_weights=True,
            weight_classes=self.num_elements,
            device=device,
            method=linear_method,
        )
        # Layout adapters between the caller's mul_ir tensors and ir_mul.
        self.to_cue = cuet.TransposeIrrepsLayout(self.irreps_in, source=cue.mul_ir,
                                                 target=cue.ir_mul, device=device)
        self.from_cue = cuet.TransposeIrrepsLayout(self.irreps_out, source=cue.ir_mul,
                                                   target=cue.mul_ir, device=device)
        self.to_cue_sh = cuet.TransposeIrrepsLayout(self.irreps_sh, source=cue.mul_ir,
                                                    target=cue.ir_mul, device=device)

    def forward(self, x, edge_index, pos, edge_attributes, edge_weights, atom_class):
        """Apply the layer.

        Args:
            x: Node features for ``irreps_in`` in ``mul_ir`` layout, one row
                per node; the output has ``x.shape[0]`` rows.
            edge_index: Row 0 = source nodes gathered by the tensor product,
                row 1 = target nodes receiving the messages.
            pos: Node positions used to build the edge vectors.
            edge_attributes: Edge lengths, divided by ``radius`` for the radial
                basis (assumed shape ``(E, 1)`` -- TODO confirm).
            edge_weights: Per-edge scalars multiplying the radial weights.
            atom_class: Per-node class index in ``[0, num_elements)`` that
                selects the per-class weights.

        Returns:
            Node features for ``irreps_out`` in ``mul_ir`` layout.
        """
        # MAKE SURE THAT INPUT IS SORTED BY ATOM_CLASS
        # Compare against the configured class count instead of a duplicated
        # literal, so the two cannot drift apart.
        assert atom_class.max().item() < self.num_elements
        edge_vec, env, d_scaled = Equivariance.compute_radial_geometry(
            pos,
            edge_index,
            cutoff=self.radius,
            edge_attributes=edge_attributes,
        )
        edge_spherical = self.spherical(edge_vec)
        # Per-edge tensor-product weights: radial MLP output scaled by edge weight.
        tp_weights = self.radial(d_scaled, env) * edge_weights.view(-1, 1)
        x_cue = self.to_cue(x).contiguous()
        # Explicit None test: works for TorchScript's optional submodule and
        # does not rely on nn.Module truthiness.
        if self.pre_linear is not None:
            x_cue = self.pre_linear(x_cue, weight_indices=atom_class)
        edge_spherical_cue = self.to_cue_sh(edge_spherical)
        m_cue = self.tp(
            x_cue,
            edge_spherical_cue,
            tp_weights,
            indices_1=edge_index[0],
            indices_out=edge_index[1],
            size_out=x.shape[0],
        ).contiguous()
        m_cue = self.tp_to_hidden(m_cue, weight_indices=atom_class)
        sc_out = self.sym_cont(m_cue, atom_class)
        out_cue = self.linear(sc_out, weight_indices=atom_class)
        return self.from_cue(out_cue)
I keep running into separate issues when calling `torch.jit.script()` on cuEquivariance modules, such as:
RuntimeError:
Variable 'indices_out' previously had type Optional[Tensor] but is now being assigned to a value of type Dict[int, Tensor]
:
File "/nas/longleaf/home/rdey/anaconda3/envs/solv3/lib/python3.10/site-packages/cuequivariance_torch/operations/tp_channel_wise.py", line 230
indices_in[2] = indices_2
if indices_out is not None:
indices_out = {0: indices_out}
~~~~~~~~~~~ <--- HERE
if size_out is None:
raise ValueError(
I got several errors like this from `SphericalHarmonics`, `ChannelWiseTensorProduct`, `SegmentedPolynomial`, and others.
While these are easy for me to fix manually, patching the installed library by hand makes the application less production-ready. I keep reading that cuEquivariance is fully compatible with TorchScript (`torch.jit.script`) and `torch.compile`, so am I missing something?
Hi there,
I was trying to JIT-compile with version 0.9.1 from scratch and kept running into numerous errors. Here is the full Python script:
I keep running into separate issues when calling `torch.jit.script()` on cuEquivariance modules, such as:
I got several errors like this from `SphericalHarmonics`, `ChannelWiseTensorProduct`, `SegmentedPolynomial`, and others.
While these are easy for me to fix manually, patching the installed library by hand makes the application less production-ready. I keep reading that cuEquivariance is fully compatible with TorchScript (`torch.jit.script`) and `torch.compile`, so am I missing something?