From 7153312bb827d1c2eb8db6bc5151c8e37107531b Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Sun, 1 Mar 2026 01:16:02 +0900 Subject: [PATCH 01/10] add openequivariance --- CHANGELOG.md | 1 + pyproject.toml | 1 + sevenn/_const.py | 2 + sevenn/_keys.py | 1 + sevenn/calculator.py | 17 +- sevenn/checkpoint.py | 9 +- sevenn/main/sevenn.py | 18 ++- sevenn/main/sevenn_get_model.py | 17 ++ sevenn/main/sevenn_inference.py | 37 +++++ sevenn/mliap.py | 5 +- sevenn/model_build.py | 25 +++ sevenn/nn/oeq_helper.py | 83 ++++++++++ sevenn/scripts/inference.py | 11 +- sevenn/util.py | 5 +- tests/lammps_tests/test_mliap.py | 59 +++++++ tests/unit_tests/test_oeq.py | 261 +++++++++++++++++++++++++++++++ 16 files changed, 541 insertions(+), 11 deletions(-) create mode 100644 sevenn/nn/oeq_helper.py create mode 100644 tests/unit_tests/test_oeq.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e08c4f8..c9d3cbba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. 
### Added - TorchSim interface +- Support OpenEquivariance ## [0.12.0] ### Added diff --git a/pyproject.toml b/pyproject.toml index c9a924ba..a5f37540 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ test = ["pytest", "pytest-cov>=5", "ipython"] torchsim = ["torch-sim-atomistic>=0.5.2; python_version >= '3.12'"] cueq12 = ["cuequivariance>=0.6.0; python_version >= '3.10'", "cuequivariance-torch>=0.6.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu12; python_version >= '3.10'"] +oeq = ["openequivariance"] mliap = ["cython==3.0.11", "cupy-cuda12x"] [project.scripts] diff --git a/sevenn/_const.py b/sevenn/_const.py index 0004e68c..6dc45589 100644 --- a/sevenn/_const.py +++ b/sevenn/_const.py @@ -131,6 +131,7 @@ def error_record_condition(x): KEY._NORMALIZE_SPH: True, KEY.USE_FLASH_TP: False, KEY.CUEQUIVARIANCE_CONFIG: {}, + KEY.USE_OEQ: False, } @@ -178,6 +179,7 @@ def error_record_condition(x): KEY._NORMALIZE_SPH: bool, KEY.USE_FLASH_TP: bool, KEY.CUEQUIVARIANCE_CONFIG: dict, + KEY.USE_OEQ: bool, } diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 2f430b1b..0c9af7b7 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -222,6 +222,7 @@ USE_FLASH_TP = 'use_flash_tp' CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' +USE_OEQ = 'use_oeq' _NORMALIZE_SPH = '_normalize_sph' OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 9515ec7f..8ec2b8ea 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -39,6 +39,7 @@ def __init__( modal: Optional[str] = None, enable_cueq: Optional[bool] = False, enable_flash: Optional[bool] = False, + enable_oeq: Optional[bool] = False, sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info **kwargs, ) -> None: @@ -64,6 +65,8 @@ def __init__( if True, use cuEquivariant to accelerate inference. enable_flash: bool, default=None (use the checkpoint's backend) if True, use FlashTP to accelerate inference. 
+ enable_oeq: bool, default=False + if True, use OpenEquivariance to accelerate inference. sevennet_config: dict | None, default=None Not used, but can be used to carry meta information of this calculator """ @@ -80,10 +83,7 @@ def __init__( enable_cueq = os.getenv('SEVENNET_ENABLE_CUEQ') == '1' or enable_cueq enable_flash = os.getenv('SEVENNET_ENABLE_FLASH') == '1' or enable_flash - print('cueq') - print(enable_cueq) - print('flash') - print(enable_flash) + enable_oeq = os.getenv('SEVENNET_ENABLE_OEQ') == '1' or enable_oeq if enable_cueq and file_type in ['model_instance', 'torchscript']: warnings.warn( @@ -91,6 +91,13 @@ def __init__( ) enable_cueq = False + # TODO: not verified this line + if enable_oeq and file_type in ['model_instance', 'torchscript']: + warnings.warn( + 'file_type should be checkpoint to enable oeq. oeq set to False' + ) + enable_oeq = False + if isinstance(device, str): # TODO: do we really need this? if device == 'auto': self.device = torch.device( @@ -105,7 +112,7 @@ def __init__( cp = util.load_checkpoint(model) model_loaded = cp.build_model( - enable_cueq=enable_cueq, enable_flash=enable_flash + enable_cueq=enable_cueq, enable_flash=enable_flash, enable_oeq=enable_oeq ) model_loaded.set_is_batch_data(False) diff --git a/sevenn/checkpoint.py b/sevenn/checkpoint.py index ec95cffa..a87e1212 100644 --- a/sevenn/checkpoint.py +++ b/sevenn/checkpoint.py @@ -312,6 +312,7 @@ def build_model( *, enable_cueq: Optional[bool] = None, enable_flash: Optional[bool] = None, + enable_oeq: Optional[bool] = None, _flash_lammps: bool = False, ) -> AtomGraphSequential: """ @@ -328,11 +329,16 @@ def build_model( cp_using_flash = self.config.get(KEY.USE_FLASH_TP, False) enable_flash = cp_using_flash if enable_flash is None else enable_flash + cp_using_oeq = self.config.get(KEY.USE_OEQ, False) + enable_oeq = cp_using_oeq if enable_oeq is None else enable_oeq + assert not _flash_lammps or enable_flash cfg_new = self.config cfg_new['_flash_lammps'] = 
_flash_lammps + cfg_new[KEY.USE_OEQ] = enable_oeq - if (cp_using_cueq, cp_using_flash) == (enable_cueq, enable_flash): + if (cp_using_cueq, cp_using_flash, cp_using_oeq) \ + == (enable_cueq, enable_flash, enable_oeq): # backend not given, or checkpoint backend is same as requested model = build_E3_equivariant_model(cfg_new) state_dict = compat.patch_state_dict_if_old( @@ -347,6 +353,7 @@ def build_model( cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': enable_cueq} cfg_new[KEY.USE_FLASH_TP] = enable_flash + cfg_new[KEY.USE_OEQ] = enable_oeq model = build_E3_equivariant_model(cfg_new) stct_src = compat.patch_state_dict_if_old( self.model_state_dict, self.config, model diff --git a/sevenn/main/sevenn.py b/sevenn/main/sevenn.py index acd1faff..43da4a36 100644 --- a/sevenn/main/sevenn.py +++ b/sevenn/main/sevenn.py @@ -47,12 +47,13 @@ def run(args): distributed_backend = args.distributed_backend use_cue = args.enable_cueq use_flash = args.enable_flash + use_oeq = args.enable_oeq if use_cue: import sevenn.nn.cue_helper if not sevenn.nn.cue_helper.is_cue_available(): - raise ImportError('cuEquivariance not installed.') + raise ImportError('cuEquivariance not installed or no GPU found.') if use_flash: import sevenn.nn.flash_helper @@ -60,6 +61,12 @@ def run(args): if not sevenn.nn.flash_helper.is_flash_available(): raise ImportError('FlashTP not installed or no GPU found.') + if use_oeq: + import sevenn.nn.oeq_helper + + if not sevenn.nn.oeq_helper.is_oeq_available(): + raise ImportError('OpenEquivariance not installed or no GPU found.') + if working_dir is None: working_dir = os.getcwd() elif not os.path.isdir(working_dir): @@ -118,6 +125,9 @@ def run(args): if use_flash: model_config[KEY.USE_FLASH_TP] = True + if use_oeq: + model_config[KEY.USE_OEQ] = True + logger.print_config(model_config, data_config, train_config) # don't have to distinguish configs inside program global_config.update(model_config) @@ -165,6 +175,12 @@ def cmd_parser_train(parser): help='use flashTP 
accelerations for training', action='store_true', ) + ag.add_argument( + '-oeq', + '--enable_oeq', + help='use OpenEquivariance accelerations for training', + action='store_true', + ) ag.add_argument( '-w', '--working_dir', diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index 04132b30..edbff46c 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -52,6 +52,12 @@ def add_args(parser): help='use cuEquivariance. Only support ML-IAP interface.', action='store_true', ) + ag.add_argument( + '-oeq', + '--enable_oeq', + help='use OpenEquivariance. Only support ML-IAP interface.', + action='store_true', + ) ag.add_argument( '-mliap', '--use_mliap', @@ -70,6 +76,7 @@ def run(args): modal = args.modal use_flash = args.enable_flash use_cueq = args.enable_cueq + use_oeq = args.enable_oeq use_mliap = args.use_mliap # Check dependencies @@ -85,9 +92,18 @@ def run(args): if not is_cue_available(): raise ImportError('cuEquivariance is not installed.') + if use_oeq: + from sevenn.nn.oeq_helper import is_oeq_available + + if not is_oeq_available(): + raise ImportError('OpenEquivariance not installed or no GPU found.') + if use_cueq and not use_mliap: raise ValueError('cuEquivariance is only supported in ML-IAP interface.') + if use_oeq and not use_mliap: + raise ValueError('OpenEquivariance is only supported in ML-IAP interface.') + if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') @@ -122,6 +138,7 @@ def run(args): modal=modal, use_cueq=use_cueq, use_flash=use_flash, + use_oeq=use_oeq, ) torch.save(mliap_module, output_prefix) diff --git a/sevenn/main/sevenn_inference.py b/sevenn/main/sevenn_inference.py index bfac2371..e2531943 100644 --- a/sevenn/main/sevenn_inference.py +++ b/sevenn/main/sevenn_inference.py @@ -66,6 +66,25 @@ def add_args(parser): default=None, help='modality for multi-modal inference', ) + ag.add_argument( + '-cueq', + '--enable_cueq', + 
help='use cuEquivariance to accelerate inference', + action='store_true', + ) + ag.add_argument( + '-flashTP', + '--enable_flash', + dest='enable_flash', + help='use FlashTP to accelerate inference', + action='store_true', + ) + ag.add_argument( + '-oeq', + '--enable_oeq', + help='use OpenEquivariance to accelerate inference', + action='store_true', + ) ag.add_argument( '--kwargs', nargs=argparse.REMAINDER, @@ -109,6 +128,21 @@ def run(args): if args.save_graph and args.allow_unlabeled: raise ValueError('save_graph and allow_unlabeled are mutually exclusive') + if args.enable_cueq: + from sevenn.nn.cue_helper import is_cue_available + if not is_cue_available(): + raise ImportError('cuEquivariance not installed or no GPU found.') + + if args.enable_flash: + from sevenn.nn.flash_helper import is_flash_available + if not is_flash_available(): + raise ImportError('FlashTP not installed or no GPU found.') + + if args.enable_oeq: + from sevenn.nn.oeq_helper import is_oeq_available + if not is_oeq_available(): + raise ImportError('OpenEquivariance not installed or no GPU found.') + inference( cp, targets, @@ -119,6 +153,9 @@ def run(args): args.save_graph, args.allow_unlabeled, args.modal, + enable_cueq=args.enable_cueq, + enable_flash=args.enable_flash, + enable_oeq=args.enable_oeq, **fmt_kwargs, ) diff --git a/sevenn/mliap.py b/sevenn/mliap.py index 29bff117..2be00873 100644 --- a/sevenn/mliap.py +++ b/sevenn/mliap.py @@ -111,6 +111,7 @@ def __init__( # calc_kwargs self.use_cueq = kwargs.get('use_cueq', False) self.use_flash = kwargs.get('use_flash', False) + self.use_oeq = kwargs.get('use_oeq', False) self.modal = kwargs.get('modal', None) # extract configs @@ -150,9 +151,9 @@ def _ensure_model_initialized(self): if self.model is not None: return # Already initialized print('[INFO] Lazy initializing SevenNet model...', flush=True) - print(f'[INFO] cueq={self.use_cueq}, flashTP={self.use_flash}', flush=True) + print(f'[INFO] cueq={self.use_cueq}, 
flashTP={self.use_flash}, oeq={self.use_oeq}', flush=True) model = self.cp.build_model( - enable_cueq=self.use_cueq, enable_flash=self.use_flash + enable_cueq=self.use_cueq, enable_flash=self.use_flash, enable_oeq=self.use_oeq ) for k, module in model._modules.items(): diff --git a/sevenn/model_build.py b/sevenn/model_build.py index 3117ce29..c548c34e 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -326,10 +326,35 @@ def patch_flash_tp(layers: OrderedDict, config: Dict[str, Any]) -> OrderedDict: return layers +def patch_oeq(layers: OrderedDict, config: Dict[str, Any]) -> OrderedDict: + import sevenn.nn.oeq_helper as oeq_helper + + if not config.get(KEY.USE_OEQ, False): + return layers + + if not oeq_helper.is_oeq_available(): + warnings.warn( + ( + 'OpenEquivariance (oeq) is requested, but the package is not ' + 'installed or GPU not available. Fallback to e3nn.' + ) + ) + return layers + + updates = {} + for k, module in layers.items(): + if isinstance(module, IrrepsConvolution): + updates[k] = oeq_helper.patch_convolution(module) + + layers.update(updates) + return layers + + def patch_modules(layers: OrderedDict, config: Dict[str, Any]) -> OrderedDict: layers = patch_modality(layers, config) layers = patch_cue(layers, config) layers = patch_flash_tp(layers, config) + layers = patch_oeq(layers, config) return layers diff --git a/sevenn/nn/oeq_helper.py b/sevenn/nn/oeq_helper.py new file mode 100644 index 00000000..f29a8214 --- /dev/null +++ b/sevenn/nn/oeq_helper.py @@ -0,0 +1,83 @@ +import torch + +from .convolution import IrrepsConvolution, IrrepsScatterGatterFusedConvolution + +try: + from openequivariance import ( + TensorProductConv, + TPProblem, + torch_to_oeq_dtype, + ) + + _OEQ_AVAILABLE = True +except ImportError: + _OEQ_AVAILABLE = False + + +def is_oeq_available() -> bool: + return _OEQ_AVAILABLE and torch.cuda.is_available() + + +def oeq_needed(func): + def wrapper(*args, **kwargs): + if is_oeq_available(): + return func(*args, 
**kwargs) + raise ImportError('OpenEquivariance (oeq) is not available') + + return wrapper + + +class OEQConvolution(torch.nn.Module): + """ + Wrapper around openequivariance.TensorProductConv to match + IrrepsScatterGatterFusedConvolution.convolution_cls interface. + """ + + def __init__( + self, + irreps_in1, + irreps_in2, + irreps_out, + instructions, + shared_weights: bool = False, + internal_weights: bool = False, + ): + super().__init__() + if not is_oeq_available(): + raise ImportError('OpenEquivariance is not available') + self.dtype = torch.get_default_dtype() + tpp = TPProblem( + irreps_in1, + irreps_in2, + irreps_out, + instructions, + irrep_dtype=torch_to_oeq_dtype(self.dtype), + weight_dtype=torch_to_oeq_dtype(self.dtype), + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + self.tp_conv = TensorProductConv(tpp, torch_op=True, deterministic=False) + + def forward(self, x, edge_filter, weight, edge_src, edge_dst): + # OEQ rows=dst, cols=src (swapped vs sevennet's arg order) + # OEQ needs int64; sevennet casts to int32 before convolution + return self.tp_conv( + x.to(self.dtype), + edge_filter.to(self.dtype), + weight.to(self.dtype), + edge_dst.to(torch.int64), # rows = dst + edge_src.to(torch.int64), # cols = src + ) + + +@oeq_needed +def patch_convolution(irreps_convolution: IrrepsConvolution): + + assert not irreps_convolution.layer_instantiated, ( + 'Convolution layer already instantiated; cannot patch' + ) + ret = IrrepsScatterGatterFusedConvolution.from_irreps_convolution( + irreps_convolution + ) + ret.convolution_cls = OEQConvolution + return ret diff --git a/sevenn/scripts/inference.py b/sevenn/scripts/inference.py index a7e99f77..c3e83428 100644 --- a/sevenn/scripts/inference.py +++ b/sevenn/scripts/inference.py @@ -125,6 +125,9 @@ def inference( save_graph: bool = False, allow_unlabeled: bool = False, modal: Optional[str] = None, + enable_cueq: bool = False, + enable_flash: bool = False, + enable_oeq: bool = False, 
**data_kwargs, ) -> None: """ @@ -151,7 +154,13 @@ def inference( at once, it will not work smoothly with data_kwargs """ - model, _ = util.model_from_checkpoint(checkpoint) + # TODO: False as default, priority? + model, _ = util.model_from_checkpoint( + checkpoint, + enable_cueq=enable_cueq or None, + enable_flash=enable_flash or None, + enable_oeq=enable_oeq or None, + ) cutoff = model.cutoff if modal: diff --git a/sevenn/util.py b/sevenn/util.py index 9bff5c89..d7ce7888 100644 --- a/sevenn/util.py +++ b/sevenn/util.py @@ -100,9 +100,12 @@ def model_from_checkpoint( *, enable_cueq: Optional[bool] = None, enable_flash: Optional[bool] = None, + enable_oeq: Optional[bool] = None, ) -> Tuple[torch.nn.Module, Dict[str, Any]]: cp = load_checkpoint(checkpoint) - model = cp.build_model(enable_cueq=enable_cueq, enable_flash=enable_flash) + model = cp.build_model( + enable_cueq=enable_cueq, enable_flash=enable_flash, enable_oeq=enable_oeq + ) return model, cp.config diff --git a/tests/lammps_tests/test_mliap.py b/tests/lammps_tests/test_mliap.py index 4c048bfb..e087659b 100644 --- a/tests/lammps_tests/test_mliap.py +++ b/tests/lammps_tests/test_mliap.py @@ -18,6 +18,7 @@ from sevenn.main.sevenn_get_model import main as mliap_cli from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available +from sevenn.nn.oeq_helper import is_oeq_available from sevenn.util import pretrained_name_to_path try: @@ -103,6 +104,22 @@ def mliap_flash_potential_path(tmp_path_factory): return pt_path +@pytest.fixture(scope='module') +def mliap_oeq_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('mliap_oeq_pt') + pt_path = str((tmp / 'oeq.pt').resolve()) + + argv = [ + 'sevenn_get_model', + cp_7net0_path, '-o', pt_path, + '--enable_oeq', + '--use_mliap' + ] + with mock.patch('sys.argv', argv): + mliap_cli() + return pt_path + + @pytest.fixture(scope='module') def ref_7net0_calculator(): return SevenNetCalculator(cp_7net0_path) @@ -358,6 
+375,27 @@ def test_mliap_flash( assert_atoms(atoms, atoms_lmp, atol=1e-5) +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize('system', ['bulk', 'surface']) +def test_mliap_oeq( + system, + mliap_oeq_potential_path, + ref_7net0_calculator, + lammps_cmd, + tmp_path +): + atoms = get_system(system) + atoms_lmp = mliap_lammps_run( + atoms=atoms, + pt_path=mliap_oeq_potential_path, + wd=tmp_path, + test_name='mliap oeq', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_calculator + assert_atoms(atoms, atoms_lmp, atol=1e-5) + + def _get_disconnected_system(name): """Build Atoms object for a disconnected test structure.""" if name.startswith('single_o'): @@ -423,6 +461,27 @@ def test_mliap_flash_disconnected( assert_atoms(atoms, atoms_lmp, atol=1e-5) +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize('system', _disconnected_systems) +def test_mliap_oeq_disconnected( + system, + mliap_oeq_potential_path, + ref_7net0_calculator, + lammps_cmd, + tmp_path, +): + atoms = _get_disconnected_system(system) + atoms_lmp = mliap_lammps_run( + atoms=atoms, + pt_path=mliap_oeq_potential_path, + wd=tmp_path, + test_name=f'mliap oeq disconnected {system}', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_calculator + assert_atoms(atoms, atoms_lmp, atol=1e-5) + + @pytest.mark.skipif(not is_cue_available(), reason='cueq not available') @pytest.mark.parametrize('system', _disconnected_systems) def test_mliap_cueq_disconnected( diff --git a/tests/unit_tests/test_oeq.py b/tests/unit_tests/test_oeq.py new file mode 100644 index 00000000..f7f3f2e7 --- /dev/null +++ b/tests/unit_tests/test_oeq.py @@ -0,0 +1,261 @@ +# TODO: add gradient test from total loss after double precision. 
+# so far, it is empirically checked by seeing learning curves +import copy + +import numpy as np +import pytest +import torch +from ase.build import bulk +from torch_geometric.loader.dataloader import Collater + +import sevenn +import sevenn.train.dataload as dl +from sevenn.atom_graph_data import AtomGraphData +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.oeq_helper import is_oeq_available +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.util import chemical_species_preprocess, model_from_checkpoint + +cutoff = 4.0 + +_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2) +_avg_num_neigh = 30.0 +_atoms.rattle() + +_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff)) + + +def get_graphs(batched): + # batch size 2 + cloned = [_graph.clone().to('cuda'), _graph.clone().to('cuda')] + if not batched: + return cloned + else: + return Collater(cloned)(cloned) + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 32, + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'self_connection_type': 'nequip', # not NequIp + 'interaction_type': 'nequip', + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': _avg_num_neigh, + 'train_denominator': False, + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + chems = set() + chems.update(_atoms.get_chemical_symbols()) + config.update(**chemical_species_preprocess(list(chems))) + return config + + +def get_model(config_overwrite=None, use_oeq=False): + cf = get_model_config() + if 
config_overwrite is not None: + cf.update(config_overwrite) + + cf['use_oeq'] = use_oeq + + model = build_E3_equivariant_model(cf, parallel=False) + assert isinstance(model, AtomGraphSequential) + model.to('cuda') + return model + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'cf', + [ + ({}), + ({'lmax': 3}), + ({'num_interaction_layer': 2}), + ({'num_interaction_layer': 4}), + ], +) +def test_model_output(cf): + torch.manual_seed(777) + model_e3nn = get_model(cf) + torch.manual_seed(777) + model_oeq = get_model(cf, use_oeq=True) + + model_e3nn.set_is_batch_data(True) + model_oeq.set_is_batch_data(True) + + e3nn_out = model_e3nn._preprocess(get_graphs(batched=True)) + oeq_out = model_oeq._preprocess(get_graphs(batched=True)) + + for k, e3nn_f in model_e3nn._modules.items(): + oeq_f = model_oeq._modules[k] + e3nn_out = e3nn_f(e3nn_out) # type: ignore + oeq_out = oeq_f(oeq_out) # type: ignore + assert torch.allclose(e3nn_out.x, oeq_out.x, atol=1e-6), ( + f'{k} \n\n {e3nn_f} \n\n {oeq_f}' + ) + + assert torch.allclose( + e3nn_out.inferred_total_energy, oeq_out.inferred_total_energy + ) + assert torch.allclose(e3nn_out.atomic_energy, oeq_out.atomic_energy) + assert torch.allclose( + e3nn_out.inferred_force, oeq_out.inferred_force, atol=1e-5 + ) + assert torch.allclose( + e3nn_out.inferred_stress, oeq_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'start_from_oeq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert(tmp_path, start_from_oeq): + torch.manual_seed(123) + model_from = get_model(use_oeq=start_from_oeq) + + cfg = get_model_config() + cfg.update( + { + 'use_oeq': start_from_oeq, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + 
model_to, _ = model_from_checkpoint( + str(tmp_path / 'cp_from.pth'), enable_oeq=(not start_from_oeq) + ) + model_to.to('cuda') + + model_from.set_is_batch_data(True) + model_to.set_is_batch_data(True) + + from_out = model_from(get_graphs(batched=True)) + to_out = model_to(get_graphs(batched=True)) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'start_from_oeq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert_no_batch(tmp_path, start_from_oeq): + torch.manual_seed(123) + model_from = get_model(use_oeq=start_from_oeq) + + cfg = get_model_config() + cfg.update( + { + 'use_oeq': start_from_oeq, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + model_to, _ = model_from_checkpoint( + str(tmp_path / 'cp_from.pth'), enable_oeq=(not start_from_oeq) + ) + model_to.to('cuda') + + model_from.set_is_batch_data(False) + model_to.set_is_batch_data(False) + + from_out = model_from(get_graphs(batched=False)[0]) + to_out = model_to(get_graphs(batched=False)[0]) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, 
b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +def test_calculator(tmp_path): + oeq = True + model = get_model(use_oeq=oeq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = copy.deepcopy(_atoms) + atoms.calc = ref_calc + + cfg = get_model_config() + cfg.update( + {'use_oeq': oeq, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + calc2 = SevenNetCalculator(cp_path, enable_oeq=False) + atoms2 = copy.deepcopy(_atoms) + atoms2.calc = calc2 + + assert_atoms(atoms, atoms2) From a903db183f7653bc5ec3e70604d21254ceb5156f Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 3 Mar 2026 17:34:42 +0900 Subject: [PATCH 02/10] lammps pair_e3gnn porting for oeq --- sevenn/main/sevenn_get_model.py | 9 +- sevenn/main/sevenn_patch_lammps.py | 25 +++ sevenn/pair_e3gnn/pair_e3gnn.cpp | 11 ++ sevenn/pair_e3gnn/pair_e3gnn_oeq_autograd.cpp | 163 ++++++++++++++++++ sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp | 11 ++ sevenn/pair_e3gnn/patch_lammps.sh | 84 ++++++--- sevenn/scripts/deploy.py | 17 +- tests/lammps_tests/test_lammps.py | 114 ++++++++++++ 8 files changed, 404 insertions(+), 30 deletions(-) create mode 100644 sevenn/pair_e3gnn/pair_e3gnn_oeq_autograd.cpp diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index edbff46c..1785d18b 100644 --- a/sevenn/main/sevenn_get_model.py 
+++ b/sevenn/main/sevenn_get_model.py @@ -55,7 +55,7 @@ def add_args(parser): ag.add_argument( '-oeq', '--enable_oeq', - help='use OpenEquivariance. Only support ML-IAP interface.', + help='use OpenEquivariance. LAMMPS must be specially compiled.', action='store_true', ) ag.add_argument( @@ -101,9 +101,6 @@ def run(args): if use_cueq and not use_mliap: raise ValueError('cuEquivariance is only supported in ML-IAP interface.') - if use_oeq and not use_mliap: - raise ValueError('OpenEquivariance is only supported in ML-IAP interface.') - if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') @@ -124,9 +121,9 @@ def run(args): from sevenn.scripts.deploy import deploy, deploy_parallel if get_serial: - deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash) + deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: - deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash) # noqa: E501 + deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: from sevenn import mliap diff --git a/sevenn/main/sevenn_patch_lammps.py b/sevenn/main/sevenn_patch_lammps.py index 641fa58c..e90b1c78 100644 --- a/sevenn/main/sevenn_patch_lammps.py +++ b/sevenn/main/sevenn_patch_lammps.py @@ -29,6 +29,13 @@ def add_args(parser): help='Enable flashTP', action='store_true', ) + ag.add_argument( + '--oeq', + '--enable_oeq', + dest='enable_oeq', + help='Enable OpenEquivariance', + action='store_true', + ) # cxx_standard is detected automatically @@ -47,6 +54,18 @@ def run(args): d3_support = '0' print(' - D3 support disabled') + so_oeq = '' + if args.enable_oeq: + try: + from openequivariance._torch.extlib import torch_ext_so_path + except ImportError: + raise ImportError('OpenEquivariance import failed.') + + so_oeq = torch_ext_so_path() + if not osp.isfile(so_oeq): + raise ValueError(f'OEQ .so file not found: 
{so_oeq}') + print(f' - OEQ support enabled: {so_oeq}') + so_lammps = '' if args.enable_flash: try: @@ -86,6 +105,12 @@ def run(args): if args.enable_flash: assert osp.isfile(so_lammps) cmd += f' {so_lammps}' + else: + cmd += ' NONE' + + if args.enable_oeq: + assert osp.isfile(so_oeq) + cmd += f' {so_oeq}' res = subprocess.run(cmd.split()) return res.returncode # is it meaningless? diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 3a4396dc..3e99b80b 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ -36,6 +36,9 @@ using namespace LAMMPS_NS; +// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked +extern void pair_e3gnn_oeq_register_autograd(); + #define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64) #define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) @@ -311,6 +314,7 @@ void PairE3GNN::coeff(int narg, char **arg) { {"version", ""}, {"dtype", ""}, {"flashTP", "version mismatch"}, + {"oeq", "version mismatch"}, {"time", ""}}; // model loading from input @@ -331,6 +335,11 @@ void PairE3GNN::coeff(int narg, char **arg) { cutoff = std::stod(meta_dict["cutoff"]); cutoff_square = cutoff * cutoff; + // to make torch::autograd::grad() works + if (meta_dict["oeq"] == "yes") { + pair_e3gnn_oeq_register_autograd(); + } + if (meta_dict["model_type"].compare("E3_equivariant_model") != 0) { error->all(FLERR, "given model type is not E3_equivariant_model"); } @@ -384,6 +393,8 @@ void PairE3GNN::coeff(int narg, char **arg) { meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); fprintf(lmp->logfile, "FlashTP: %s\n", meta_dict["flashTP"].c_str()); + fprintf(lmp->logfile, "OEQ: %s\n", + meta_dict["oeq"].c_str()); } } diff --git a/sevenn/pair_e3gnn/pair_e3gnn_oeq_autograd.cpp b/sevenn/pair_e3gnn/pair_e3gnn_oeq_autograd.cpp new file mode 100644 index 00000000..236b447e --- /dev/null +++ b/sevenn/pair_e3gnn/pair_e3gnn_oeq_autograd.cpp @@ -0,0 +1,163 @@ +/* 
----------------------------------------------------------------------- + pair_e3gnn_oeq_autograd.cpp + + It registers the Autograd dispatch key for libtorch_tp_jit::jit_conv_forward + so that pair_e3gnn can compute forces via autograd in pure C++ LibTorch. + In Python, these registrations happen in: + openequivariance._torch.TensorProductConv (jit_conv_forward) + + Contributing author: Jinmu You (SNU) + ----------------------------------------------------------------------- */ + +#include +#include +#include +#include +#include + +using namespace torch::autograd; + +// --------------------------------------------------------------------------- +// Helper: call libtorch_tp_jit::jit_conv_backward through the dispatcher. +// Signature: +// jit_conv_backward(Tensor json_bytes, int hash, +// Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, +// Tensor rows, Tensor cols, +// Tensor workspace, Tensor transpose_perm) +// -> (Tensor L1_grad, Tensor L2_grad, Tensor W_grad) +// --------------------------------------------------------------------------- +static std::tuple dispatch_jit_conv_backward( + const at::Tensor& json_bytes, int64_t hash, + const at::Tensor& L1_in, const at::Tensor& L2_in, + const at::Tensor& W, const at::Tensor& L3_grad, + const at::Tensor& rows, const at::Tensor& cols, + const at::Tensor& workspace, const at::Tensor& transpose_perm) +{ + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("libtorch_tp_jit::jit_conv_backward", ""); + c10::Stack stack; + stack.reserve(10); + stack.emplace_back(json_bytes); + stack.emplace_back(hash); + stack.emplace_back(L1_in); + stack.emplace_back(L2_in); + stack.emplace_back(W); + stack.emplace_back(L3_grad); + stack.emplace_back(rows); + stack.emplace_back(cols); + stack.emplace_back(workspace); + stack.emplace_back(transpose_perm); + + // Exclude autograd keys so callBoxed goes straight to CUDA. 
+ c10::impl::ExcludeDispatchKeyGuard no_autograd( + c10::autograd_dispatch_keyset); + op.callBoxed(&stack); + + TORCH_INTERNAL_ASSERT(stack.size() == 3, + "jit_conv_backward: expected 3 output tensors, got ", stack.size()); + return {stack[0].toTensor(), stack[1].toTensor(), stack[2].toTensor()}; +} + +// --------------------------------------------------------------------------- +// Custom autograd Function wrapping jit_conv_forward. +// --------------------------------------------------------------------------- +struct JitConvAutograd : public Function { + static at::Tensor forward( + AutogradContext* ctx, + at::Tensor json_bytes, int64_t hash, + at::Tensor L1_in, at::Tensor L2_in, at::Tensor W, + int64_t L3_dim, + at::Tensor rows, at::Tensor cols, + at::Tensor workspace, at::Tensor transpose_perm) + { + ctx->save_for_backward( + {json_bytes, L1_in, L2_in, W, rows, cols, workspace, transpose_perm}); + ctx->saved_data["hash"] = hash; + ctx->saved_data["L3_dim"] = L3_dim; + + static auto fwd_op = c10::Dispatcher::singleton() + .findSchemaOrThrow("libtorch_tp_jit::jit_conv_forward", ""); + c10::Stack stack; + stack.reserve(10); + stack.emplace_back(json_bytes); + stack.emplace_back(hash); + stack.emplace_back(L1_in); + stack.emplace_back(L2_in); + stack.emplace_back(W); + stack.emplace_back(L3_dim); + stack.emplace_back(rows); + stack.emplace_back(cols); + stack.emplace_back(workspace); + stack.emplace_back(transpose_perm); + + c10::impl::ExcludeDispatchKeyGuard no_autograd( + c10::autograd_dispatch_keyset); + fwd_op.callBoxed(&stack); + + TORCH_INTERNAL_ASSERT(stack.size() == 1, + "jit_conv_forward: expected 1 output tensor, got ", stack.size()); + return stack[0].toTensor(); + } + + static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) + { + const auto& saved = ctx->get_saved_variables(); + auto json_bytes = saved[0]; + auto L1_in = saved[1]; + auto L2_in = saved[2]; + auto W = saved[3]; + auto rows = saved[4]; + auto cols = saved[5]; + auto 
workspace = saved[6]; + auto transpose_perm = saved[7]; + auto hash = ctx->saved_data["hash"].toInt(); + auto L3_grad = grad_outputs[0]; + + auto [L1_grad, L2_grad, W_grad] = dispatch_jit_conv_backward( + json_bytes, hash, L1_in, L2_in, W, L3_grad, + rows, cols, workspace, transpose_perm); + + return { + at::Tensor(), // json_bytes (no grad) + at::Tensor(), // hash (no grad) + L1_grad, // L1_in + L2_grad, // L2_in + W_grad, // W + at::Tensor(), // L3_dim (no grad) + at::Tensor(), // rows (no grad) + at::Tensor(), // cols (no grad) + at::Tensor(), // workspace (no grad) + at::Tensor(), // transpose_perm (no grad) + }; + } +}; + +// --------------------------------------------------------------------------- +// pair_e3gnn_oeq_register_autograd() +// +// Called from PairE3GNN::coeff() when "oeq: yes". +// Similar to Python's torch.library.register_autograd(). +// --------------------------------------------------------------------------- +void pair_e3gnn_oeq_register_autograd() +{ + static torch::Library lib = [] { + torch::Library l( + torch::Library::IMPL, + "libtorch_tp_jit", + std::optional(c10::DispatchKey::Autograd), + __FILE__, + __LINE__); + l.impl( + "jit_conv_forward", + [](at::Tensor json_bytes, int64_t hash, + at::Tensor L1_in, at::Tensor L2_in, at::Tensor W, + int64_t L3_dim, + at::Tensor rows, at::Tensor cols, + at::Tensor workspace, at::Tensor transpose_perm) -> at::Tensor { + return JitConvAutograd::apply( + json_bytes, hash, L1_in, L2_in, W, L3_dim, + rows, cols, workspace, transpose_perm); + }); + return std::move(l); + }(); +} diff --git a/sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp b/sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp index a31d4b89..de8c6a11 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp @@ -50,6 +50,9 @@ using namespace LAMMPS_NS; +// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked +extern void pair_e3gnn_oeq_register_autograd(); + #define INTEGER_TYPE 
torch::TensorOptions().dtype(torch::kInt64) #define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) @@ -562,6 +565,7 @@ void PairE3GNNParallel::coeff(int narg, char **arg) { {"dtype", ""}, {"time", ""}, {"flashTP", "version mismatch"}, + {"oeq", "version mismatch"}, {"comm_size", ""}}; // model loading from input @@ -609,6 +613,11 @@ void PairE3GNNParallel::coeff(int narg, char **arg) { cutoff_square = cutoff * cutoff; + // to make torch::autograd::grad() works + if (meta_dict["oeq"] == "yes") { + pair_e3gnn_oeq_register_autograd(); + } + if (meta_dict["model_type"].compare("E3_equivariant_model") != 0) { error->all(FLERR, "given model type is not E3_equivariant_model"); } @@ -663,6 +672,8 @@ void PairE3GNNParallel::coeff(int narg, char **arg) { meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); fprintf(lmp->logfile, "FlashTP: %s\n", meta_dict["flashTP"].c_str()); + fprintf(lmp->logfile, "OEQ: %s\n", + meta_dict["oeq"].c_str()); } } diff --git a/sevenn/pair_e3gnn/patch_lammps.sh b/sevenn/pair_e3gnn/patch_lammps.sh index 44fbf049..ec46c1ba 100755 --- a/sevenn/pair_e3gnn/patch_lammps.sh +++ b/sevenn/pair_e3gnn/patch_lammps.sh @@ -4,15 +4,18 @@ lammps_root=$1 cxx_standard=$2 # 14, 17 d3_support=$3 # 1, 0 flashTP_so="${4:-NONE}" +oeq_so="${5:-NONE}" SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +enable_flashTP=0 +enable_oeq=0 ########################################### # Check if the given arguments are valid # ########################################### # Check the number of arguments -if [ "$#" -lt 3 ] || [ "$#" -gt 4 ]; then - echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so}" +if [ "$#" -lt 3 ] || [ "$#" -gt 5 ]; then + echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so} {oeq_so}" echo " {lammps_root}: Root directory of LAMMPS source" echo " {cxx_standard}: C++ standard (14, 17)" echo " {d3_support}: Support for pair_d3 (1, 0)" @@ -37,14 +40,26 @@ if [ ! 
-f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then exit 1 fi -if [ "$flashTP_so" != "NONE" ] && [ -f "$flashTP_so" ]; then +if [ -f "$flashTP_so" ]; then echo "Using flashTP_so: $flashTP_so" + enable_flashTP=1 +elif [ "$flashTP_so" = "NONE" ]; then + echo "[FlashTP] Skipped: not provided" else - echo "Invalid or missing flashTP_so given" - echo "[FlashTP] Skipped (not provided)" + echo "[FlashTP] Skipped: Invalid or missing flashTP_so given" flashTP_so="NONE" fi +if [ -f "$oeq_so" ]; then + echo "Using oeq_so: $oeq_so" + enable_oeq=1 +elif [ "$oeq_so" = "NONE" ]; then + echo "[OEQ] Skipped: not provided" +else + echo "[OEQ] Skipped: nvalid or missing oeq_so given" + oeq_so="NONE" +fi + # Check if the patch is already applied if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then echo "----------------------------------------------------------" @@ -106,6 +121,8 @@ cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt # 1. Copy pair_e3gnn files to LAMMPS source cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ +# Always copy the oEq autograd bridge (pair_e3gnn.cpp has an extern reference to it) +cp $SCRIPT_DIR/pair_e3gnn_oeq_autograd.cpp $lammps_root/src/ # TODO: set this as oeq-specific # 2. 
Patch cmake/CMakeLists.txt sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt @@ -116,24 +133,37 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") EOF -if [ "$flashTP_so" != "NONE" ]; then - -cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" - -find_package(Python3 REQUIRED COMPONENTS Development) -target_link_libraries(lammps PUBLIC -Wl,--no-as-needed - ${CMAKE_CURRENT_LIST_DIR}/flashTP/libflashtp_large_kernel_lammps.so - Python3::Python -) -set_target_properties(lammps PROPERTIES - BUILD_RPATH "${CMAKE_CURRENT_LIST_DIR}/flashTP" -) -EOF +optional_libs=() +optional_rpaths=() -echo "[FlashTP] CMakeLists.txt is patched" +if [ "$enable_flashTP" -eq 1 ]; then + optional_libs+=(" \${CMAKE_CURRENT_LIST_DIR}/flashTP/libflashtp_large_kernel_lammps.so") + optional_rpaths+=("\${CMAKE_CURRENT_LIST_DIR}/flashTP") + mkdir -p $lammps_root/cmake/flashTP && cp $flashTP_so $lammps_root/cmake/flashTP/libflashtp_large_kernel_lammps.so && echo "[FlashTP] flashTP so file copied" +fi -mkdir -p $lammps_root/cmake/flashTP && cp $flashTP_so $lammps_root/cmake/flashTP/libflashtp_large_kernel_lammps.so && echo "[FlashTP] flashTP so file copied" +if [ "$enable_oeq" -eq 1 ]; then + optional_libs+=(" \${CMAKE_CURRENT_LIST_DIR}/oeq/liboeq_stable_cuda_aoti.so") + optional_rpaths+=("\${CMAKE_CURRENT_LIST_DIR}/oeq") + mkdir -p $lammps_root/cmake/oeq && cp $oeq_so $lammps_root/cmake/oeq/liboeq_stable_cuda_aoti.so && echo "[OEQ] oeq so file copied" +fi +if [ ${#optional_libs[@]} -gt 0 ]; then + rpath_joined=$(IFS=';'; echo "${optional_rpaths[*]}") + { + echo "" + echo "find_package(Python3 REQUIRED COMPONENTS Development)" + echo "target_link_libraries(lammps PUBLIC -Wl,--no-as-needed" + for lib in "${optional_libs[@]}"; do + echo "$lib" + done + echo " Python3::Python" + echo ")" + echo "set_property(TARGET lammps APPEND PROPERTY" + echo " BUILD_RPATH 
\"$rpath_joined\"" + echo ")" + } >> $lammps_root/cmake/CMakeLists.txt + echo "[Optional] CMakeLists.txt is patched" fi @@ -170,11 +200,23 @@ echo "Changes made:" echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" echo " - Copied contents of pair_e3gnn to $lammps_root/src/" echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" +echo +if [ "$enable_flashTP" -eq 1 ]; then + echo " - Copied flashTP .so to $lammps_root/cmake/flashTP/" + echo " - Patched CMakeLists.txt: link libflashtp_large_kernel_lammps.so" + echo +fi +if [ "$enable_oeq" -eq 1 ]; then + echo " - Copied oeq .so to $lammps_root/cmake/oeq/" + echo " - Copied pair_e3gnn_oeq_autograd.cpp to $lammps_root/src/" + echo " - Patched CMakeLists.txt: link liboeq_stable_cuda_aoti.so" + echo +fi if [ "$d3_support" -ne 0 ]; then echo " - Copied contents of pair_d3 to $lammps_root/src/" echo " - Patched CMakeLists.txt: include CUDA" + echo fi - # Provide example cmake command to the user echo "Example build commands, under LAMMPS root" echo " mkdir build; cd build" diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 869e2848..c81a9ad1 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -18,6 +18,7 @@ def deploy( fname='deployed_serial.pt', modal: Optional[str] = None, use_flash: bool = False, + use_oeq: bool = False, ) -> None: from sevenn.nn.edge_embedding import EdgePreprocess from sevenn.nn.force_output import ForceStressOutput @@ -26,7 +27,10 @@ def deploy( model, config = ( cp.build_model( - enable_cueq=False, enable_flash=use_flash, _flash_lammps=use_flash + enable_cueq=False, + enable_flash=use_flash, + enable_oeq=use_oeq, + _flash_lammps=use_flash, ), cp.config, ) @@ -63,6 +67,7 @@ def deploy( md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) + 
md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) md_configs.update( {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} ) @@ -81,6 +86,7 @@ def deploy_parallel( fname='deployed_parallel', modal: Optional[str] = None, use_flash: bool = False, + use_oeq: bool = False, ) -> None: # Additional layer for ghost atom (and copy parameters from original) GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1'] @@ -89,12 +95,16 @@ def deploy_parallel( model, config = ( cp.build_model( - enable_cueq=False, enable_flash=use_flash, _flash_lammps=use_flash + enable_cueq=False, + enable_flash=use_flash, + enable_oeq=use_oeq, + _flash_lammps=use_flash, ), cp.config, ) config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False} config[KEY.USE_FLASH_TP] = use_flash + config[KEY.USE_OEQ] = use_oeq config['_flash_lammps'] = use_flash model_state_dct = model.state_dict() @@ -117,7 +127,7 @@ def deploy_parallel( if hasattr(model_part, 'eval_type_map'): setattr(model_part, 'eval_type_map', False) # Ensure all values are inserted - assert len(missing) == 0 or use_flash, missing + assert len(missing) == 0 or use_flash or use_oeq, missing if modal: model_list[0].prepare_modal_deploy(modal) @@ -147,6 +157,7 @@ def deploy_parallel( md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) md_configs.update({'comm_size': str(comm_size)}) md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) + md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) md_configs.update( {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} ) diff --git a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 314b8504..8847feea 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -17,6 +17,7 @@ from sevenn.model_build import build_E3_equivariant_model from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available +from sevenn.nn.oeq_helper import is_oeq_available from 
sevenn.scripts.deploy import deploy, deploy_parallel from sevenn.util import chemical_species_preprocess, pretrained_name_to_path @@ -58,6 +59,14 @@ def serial_potential_path_flash(tmp_path_factory): return pot_path +@pytest.fixture(scope='module') +def serial_potential_path_oeq(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_oeq') + pot_path = str(tmp / 'deployed_serial.pt') + deploy(cp_7net0_path, pot_path, use_oeq=True) + return pot_path + + @pytest.fixture(scope='module') def parallel_potential_path(tmp_path_factory): tmp = tmp_path_factory.mktemp('paralllel_potential') @@ -321,6 +330,26 @@ def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_p assert_atoms(atoms, atoms_lammps) +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_oeq( + system, serial_potential_path_oeq, ref_7net0_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + test_name='serial oeq lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5) + + @pytest.mark.skipif(not is_flash_available(), reason='flash not available') @pytest.mark.parametrize( 'system', @@ -511,6 +540,70 @@ def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path): assert_atoms(atoms, atoms_lammps) +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +def test_oeq_serial(lammps_cmd, tmp_path): + oeq = True + model = get_model() + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('bulk') + + cfg = get_model_config() + cfg.update({'version': sevenn.__version__}) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, 
+ ) + + pot_path = str(tmp_path / 'deployed_oeq_serial.pt') + deploy(cp_path, pot_path, use_oeq=oeq) + + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=pot_path, + wd=tmp_path, + test_name='oeq checkpoint serial lmp run test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +def test_oeq_parallel(lammps_cmd, mpirun_cmd, tmp_path): + oeq = True + model = get_model() + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('surface', replicate=(4, 4, 1)) + + cfg = get_model_config() + cfg.update({'version': sevenn.__version__}) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + pot_path = str(tmp_path / 'deployed_oeq_parallel') + deploy_parallel(cp_path, pot_path, use_oeq=oeq) + + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=' '.join([str(cfg['num_convolution_layer']), pot_path]), + wd=tmp_path, + test_name='oeq checkpoint parallel lmp run test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=2, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) + + def _get_disconnected_system(name): """Build Atoms object for a disconnected test structure.""" if name.startswith('single_o'): @@ -574,3 +667,24 @@ def test_disconnected_serial_flash( ) atoms.calc = ref_7net0_calculator assert_atoms(atoms, atoms_lammps, atol=1e-5) + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize('system', _disconnected_systems) +def test_disconnected_serial_oeq( + system, + serial_potential_path_oeq, + ref_7net0_calculator, + lammps_cmd, + tmp_path, +): + atoms = _get_disconnected_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + 
test_name=f'disconnected serial oeq {system}', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5) From 8df70e7755678feb6a9cf826b3f826e867ca24ef Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 3 Mar 2026 17:57:51 +0900 Subject: [PATCH 03/10] add optional dependencies --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6f856a0e..42259439 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,10 @@ dependencies = [ test = ["pytest", "pytest-cov>=5", "ipython"] torchsim = ["torch-sim-atomistic>=0.5.2; python_version >= '3.12'"] cueq12 = ["cuequivariance>=0.6.0; python_version >= '3.10'", "cuequivariance-torch>=0.6.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu12; python_version >= '3.10'"] +cueq13 = ["cuequivariance>=0.7.0; python_version >= '3.10'", "cuequivariance-torch>=0.7.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu13; python_version >= '3.10'"] oeq = ["openequivariance"] -mliap = ["cython==3.0.11", "cupy-cuda12x"] +mliap12 = ["cython==3.0.11", "cupy-cuda12x"] +mliap13 = ["cython==3.0.11", "cupy-cuda13x"] [project.scripts] sevenn = "sevenn.main.sevenn:main" From a0f7353ce696d23b42ba69915c7fc14e4deeafdc Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 3 Mar 2026 18:09:50 +0900 Subject: [PATCH 04/10] flake8 --- sevenn/calculator.py | 2 +- sevenn/mliap.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 8ec2b8ea..5b7467d3 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -112,7 +112,7 @@ def __init__( cp = util.load_checkpoint(model) model_loaded = cp.build_model( - enable_cueq=enable_cueq, enable_flash=enable_flash, enable_oeq=enable_oeq + enable_cueq=enable_cueq, enable_flash=enable_flash, enable_oeq=enable_oeq # noqa: E501 ) model_loaded.set_is_batch_data(False) diff --git a/sevenn/mliap.py b/sevenn/mliap.py 
index 2be00873..513d55c7 100644 --- a/sevenn/mliap.py +++ b/sevenn/mliap.py @@ -151,9 +151,9 @@ def _ensure_model_initialized(self): if self.model is not None: return # Already initialized print('[INFO] Lazy initializing SevenNet model...', flush=True) - print(f'[INFO] cueq={self.use_cueq}, flashTP={self.use_flash}, oeq={self.use_oeq}', flush=True) + print(f'[INFO] cueq={self.use_cueq}, flashTP={self.use_flash}, oeq={self.use_oeq}', flush=True) # noqa: E501 model = self.cp.build_model( - enable_cueq=self.use_cueq, enable_flash=self.use_flash, enable_oeq=self.use_oeq + enable_cueq=self.use_cueq, enable_flash=self.use_flash, enable_oeq=self.use_oeq # noqa: E501 ) for k, module in model._modules.items(): From 5302ea1d9d681274efaa5a55f76afec4ab1cf199 Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 10 Mar 2026 13:21:16 +0900 Subject: [PATCH 05/10] add oeq docs --- docs/source/user_guide/accelerator.md | 48 ++++++++++++++++++++---- docs/source/user_guide/ase_calculator.md | 6 +-- docs/source/user_guide/cli.md | 10 ++--- docs/source/user_guide/lammps_mliap.md | 8 ++-- docs/source/user_guide/lammps_torch.md | 4 +- 5 files changed, 56 insertions(+), 20 deletions(-) diff --git a/docs/source/user_guide/accelerator.md b/docs/source/user_guide/accelerator.md index d668ebc9..5619324b 100644 --- a/docs/source/user_guide/accelerator.md +++ b/docs/source/user_guide/accelerator.md @@ -3,14 +3,14 @@ This document describes available accelerator integrations in SevenNet and their installation guide. :::{caution} -We do not support CuEquivariance for [LAMMPS: Torch](./lammps_mliap.md). You must use [LAMMPS: ML-IAP](./lammps_torch.md) for CuEquivariance. +We do not support CuEquivariance for [LAMMPS: Torch](./lammps_torch.md). You must use [LAMMPS: ML-IAP](./lammps_mliap.md) for CuEquivariance. ::: -[CuEquivariance](https://github.com/NVIDIA/cuEquivariance) and [FlashTP](https://openreview.net/forum?id=wiQe95BPaB) provide acceleration for both SevenNet training and inference. 
(For speed, check the section 2.7 of [SevenNet-Omni paper](https://arxiv.org/abs/2510.11241)) +[FlashTP](https://openreview.net/forum?id=wiQe95BPaB), [CuEquivariance](https://github.com/NVIDIA/cuEquivariance), and [OpenEquivariance](https://github.com/PASSIONLab/OpenEquivariance) provide acceleration for both SevenNet training and inference. :::{tip} -For small systems, FlashTP with [LAMMPS: Torch](./lammps_mliap.md) shows performance advantage over cuEquivariance with [LAMMPS: ML-IAP](./lammps_torch.md). -A performance crossover occurs at around 10³ atoms, beyond which cuEquivariance becomes more efficient. +For small systems, FlashTP with [LAMMPS: Torch](./lammps_torch.md) shows performance advantage over cuEquivariance with [LAMMPS: ML-IAP](./lammps_mliap.md). +A performance crossover occurs at around 10³ atoms, beyond which cuEquivariance becomes more efficient. (For more information, check the section 2.7 of [SevenNet-Omni paper](https://arxiv.org/abs/2510.11241)) FlashTP with [LAMMPS: Torch](./lammps_mliap.md) is generally faster than FlashTP with [LAMMPS: ML-IAP](./lammps_mliap.md). ::: @@ -24,9 +24,16 @@ CuEquivariance is an NVIDIA Python library designed to facilitate the constructi - cuEquivariance >= 0.6.1 ### Installation -After installation of SevenNet, install cuEquivariance following their guideline. -To use cuEquivariance with SevenNet, you need to install `cuequivariance`, `cuequivariance-torch`, and `cuequivariance-ops-torch-cu{12, 13}` (depending on your CUDA version) -[cuEquivariance](https://github.com/NVIDIA/cuEquivariance). +After installation of SevenNet, install cuEquivariance via the pip extras: + +```bash +pip install sevenn[cueq12] # For CUDA 12.x +pip install sevenn[cueq13] # For CUDA 13.x +``` + +This will install `cuequivariance`, `cuequivariance-torch`, and the corresponding `cuequivariance-ops-torch-cu{12,13}`. 
+ +Alternatively, you can install cuEquivariance manually following [their guideline](https://github.com/NVIDIA/cuEquivariance). :::{note} Some GeForce GPUs do not support `pynvml`, @@ -79,8 +86,33 @@ True For more information, see [FlashTP](https://github.com/SNU-ARC/flashTP). +## [OpenEquivariance (experimental)](https://github.com/PASSIONLab/OpenEquivariance) + +:::{caution} +OpenEquivariance support in SevenNet is currently experimental. +::: + +OpenEquivariance (oEq) is a library for acceleration of equivariant tensor products. Unlike cuEquivariance, its kernels are compiled once per irreps configuration, eliminating runtime recompilation overhead during training on diverse datasets. + +### Requirements +- Python >= 3.10 +- openequivariance + +### Installation +```bash +pip install sevenn[oeq] +``` + +Check your installation: +```bash +python -c 'from sevenn.nn.oeq_helper import is_oeq_available; print(is_oeq_available())' +True +``` + +For more information, see [OpenEquivariance](https://github.com/PASSIONLab/OpenEquivariance). + ## Usage -After the installation, you can leverage the accelerator with appropriate flag (`--enable_cueq`) or (`--enable_flash`) options +After the installation, you can leverage the accelerator with appropriate flag (`--enable_cueq`), (`--enable_flash`), or (`--enable_oeq`) options - [Training](./cli.md#sevenn-train) - [ASE Calculator](./ase_calculator.md) diff --git a/docs/source/user_guide/ase_calculator.md b/docs/source/user_guide/ase_calculator.md index 2a44c638..4d3d1c4c 100644 --- a/docs/source/user_guide/ase_calculator.md +++ b/docs/source/user_guide/ase_calculator.md @@ -14,11 +14,11 @@ from sevenn.calculator import SevenNetD3Calculator calc = SevenNetD3Calculator(model='7net-0', device='cuda') ``` -Use enable_cueq or enable_flash to use cuEquivariance or flashTP for faster inference. 
-For more information about cuEq and flashTP, follow [here](./accelerator.md) +Use enable_cueq, enable_flash, or enable_oeq to use cuEquivariance, flashTP, or OpenEquivariance for faster inference. +For more information about accelerators, follow [here](./accelerator.md) ```python from sevenn.calculator import SevenNetCalculator -calc = SevenNetCalculator(model='7net-0', enable_cueq=True) # or enable_flash=True +calc = SevenNetCalculator(model='7net-0', enable_cueq=True) # or enable_flash=True or enable_oeq=True ``` If you encounter the error `CUDA is not installed or nvcc is not available`, please ensure the `nvcc` compiler is available. Currently, CPU + D3 is not supported. diff --git a/docs/source/user_guide/cli.md b/docs/source/user_guide/cli.md index 153f45d9..06c971a9 100644 --- a/docs/source/user_guide/cli.md +++ b/docs/source/user_guide/cli.md @@ -44,10 +44,10 @@ We support multi-GPU training using PyTorch DDP (distributed data parallel) with torchrun --standalone --nnodes {number of nodes} --nproc_per_node {number of GPUs} --no_python sevenn input.yaml -d ``` -Enable cuEquivariance or flashTP while training by adding the --enable_cueq or --enable_flash. +Enable cuEquivariance, flashTP, or OpenEquivariance while training by adding the --enable_cueq, --enable_flash, or --enable_oeq. ```bash -sevenn train input.yaml -s --enable_cueq # or --enable_flash +sevenn train input.yaml -s --enable_cueq # or --enable_flash or --enable_oeq ``` Please note that `batch_size` in `input.yaml` refers to the per-GPU batch size. @@ -79,19 +79,19 @@ See {doc}`accelerator` for installation of accelerators. sevenn get_model \ {pretrained_name or checkpoint_path} \ {--use_mliap} \ # For LAMMPS ML-IAP use. - {--enable_flash | --enable_cueq} \ # For accelerators. + {--enable_flash | --enable_cueq | --enable_oeq} \ # For accelerators. 
{--modal {task_name}} \ # Required when using multi-fidelity model {--get_parallel} # For parallel MD simulations ``` If `--use_mliap` is not set (TorchScript version), it will create `deployed_serial.pt` or a directory containing several `deployed_parallel_*.pt` files for parallel execution. They can be used as a LAMMPS potential with the `e3gnn` or `e3gnn/parallel` pair_style in LAMMPS. -In this case, only `--enable_flash` is available. +In this case, `--enable_flash` and `--enable_oeq` are available. Check {doc}`lammps_torch` for installation and lammps script in this use case. If `--use_mliap` is set, it will create `deployed_serial_mliap.pt`. The file can be used with the `mliap` pair_style in LAMMPS. -In this case, both `--enable_cueq` and `--enable_flash` are available, but you cannot specify both at the same time. +In this case, `--enable_cueq`, `--enable_flash`, and `--enable_oeq` are available, but you cannot specify multiple at the same time. Parallel execution is not tested in this case. Check {doc}`lammps_mliap` for installation and lammps script in this use case. diff --git a/docs/source/user_guide/lammps_mliap.md b/docs/source/user_guide/lammps_mliap.md index d1b4b743..4b564295 100644 --- a/docs/source/user_guide/lammps_mliap.md +++ b/docs/source/user_guide/lammps_mliap.md @@ -6,13 +6,15 @@ Currently the parallel implementation of LAMMPS/ML-IAP is not tested. ## Requirements - cython == 3.0.11 -- cupy-cuda12x +- cupy-cuda12x or cupy-cuda13x - flashTP (optional, follow [here](accelerator.md#flashtp)) - cuEquivariance (optional, follow [here](accelerator.md#cuequivariance)) +- OpenEquivariance (optional, follow [here](accelerator.md#openequivariance)) Install via: ```bash -pip install sevenn[mliap] +pip install sevenn[mliap12] # For CUDA 12.x +pip install sevenn[mliap13] # For CUDA 13.x ``` ## Build @@ -119,7 +121,7 @@ Please check [sevenn graph_build](./cli.md#sevenn-graph-build) for detail. 
An ML-IAP potential checkpoint can be deployed using ``sevenn get_model`` command with ``--use_mliap`` flag. - By default, output file name will be ``deployed_serial_mliap.pt``. (You can customize the output file name using ``--output_prefix`` flag.) -- You can accelerate the inference with ``--enable_cueq`` or ``--enable_flash`` flag: +- You can accelerate the inference with ``--enable_cueq``, ``--enable_flash``, or ``--enable_oeq`` flag: ```bash sevenn get_model \ {pretrained_name or checkpoint_path} \ diff --git a/docs/source/user_guide/lammps_torch.md b/docs/source/user_guide/lammps_torch.md index a6d50659..8f2425f9 100644 --- a/docs/source/user_guide/lammps_torch.md +++ b/docs/source/user_guide/lammps_torch.md @@ -16,13 +16,15 @@ Ensure the LAMMPS version is `stable_2Aug2023_update3`. You can easily switch th ```bash git clone https://github.com/lammps/lammps.git lammps_sevenn --branch stable_2Aug2023_update3 --depth=1 -sevenn patch_lammps ./lammps_sevenn {--enable_flash} {--d3} +sevenn patch_lammps ./lammps_sevenn {--enable_flash} {--enable_oeq} {--d3} ``` You can refer to `sevenn/pair_e3gnn/patch_lammps.sh` for details of the patch process. :::{tip} (Optional) Add `--enable_flash` option to accelerate SevenNet for LAMMPS using flashTP. You must preinstall [flashTP](accelerator.md#flashtp) before building LAMMPS with flashTP. + +(Optional) Add `--enable_oeq` option to accelerate SevenNet for LAMMPS using OpenEquivariance. You must preinstall [OpenEquivariance](accelerator.md#openequivariance) before building LAMMPS with OpenEquivariance. 
::: ### (Optional) Build with GPU-D3 pair style From 22da5af1df2582d29cef28ab9c78839d0f0ca7a3 Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 10 Mar 2026 13:43:26 +0900 Subject: [PATCH 06/10] add oeq flag in torchsim docs --- docs/source/user_guide/torchsim.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/user_guide/torchsim.md b/docs/source/user_guide/torchsim.md index b1626266..005f6b54 100644 --- a/docs/source/user_guide/torchsim.md +++ b/docs/source/user_guide/torchsim.md @@ -24,6 +24,12 @@ from sevenn.torchsim import SevenNetModel model = SevenNetModel(model="7net-omni", modal="mpa") ``` +You can enable accelerators (cuEquivariance, flashTP, or OpenEquivariance) via the corresponding flags. For more information about accelerators, follow [here](./accelerator.md). +```python +model = SevenNetModel(model="7net-omni", modal="mpa", enable_oeq=True) +# or enable_cueq=True or enable_flash=True +``` + The `device` parameter defaults to `'auto'` (CUDA if available, otherwise CPU). 
### Batched MD From d227526574011e8dc1881262e7cfe2ca904f570e Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 10 Mar 2026 13:44:56 +0900 Subject: [PATCH 07/10] force only on TP accelerator when build model --- sevenn/calculator.py | 34 ++++++++++++++++++++++++---------- sevenn/checkpoint.py | 3 +++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 5b7467d3..2f4a3d59 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -251,12 +251,16 @@ def __init__( model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', file_type: str = 'checkpoint', device: Union[torch.device, str] = 'auto', - sevennet_config: Optional[Any] = None, # hold meta information + modal: Optional[str] = None, + enable_cueq: Optional[bool] = False, + enable_flash: Optional[bool] = False, + enable_oeq: Optional[bool] = False, + sevennet_config: Optional[Any] = None, damping_type: str = 'damp_bj', functional_name: str = 'pbe', vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au - **kwargs, + **kwargs, # pass extra kwargs to both calculators ) -> None: """Initialize SevenNetD3Calculator. CUDA required. @@ -270,21 +274,27 @@ def __init__( device: str | torch.device, default='auto' if not given, use CUDA if available modal: str | None, default=None - modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, - for 7net-omni, use one of 'mpa', 'omat24', 'matpes_pbe', 'matpes_r2scan', - 'mp_r2scan', 'oc20', 'oc22', 'odac23', 'omol25_low', 'omol25_high', - 'spice', 'qcml', 'pet_mad' - it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) + modal (fidelity) if given model is multi-modal model. + for 7net-mf-ompa, it should be one of 'mpa' or 'omat24'. 
+ for 7net-omni, use one of 'mpa', 'omat24', 'matpes_pbe', + 'matpes_r2scan', 'mp_r2scan', 'oc20', 'oc22', 'odac23', + 'omol25_low', 'omol25_high', 'spice', 'qcml', 'pet_mad' enable_cueq: bool, default=False - if True, use cuEquivariant to accelerate inference. + if True, use cuEquivariance backend. + enable_flash: bool, default=False + if True, use flashTP backend. + enable_oeq: bool, default=False + if True, use OpenEquivariance backend. + sevennet_config: dict | None, default=None + hold meta information damping_type: str, default='damp_bj' Damping type of D3, one of 'damp_bj' | 'damp_zero' functional_name: str, default='pbe' Target functional name of D3 parameters. vdw_cutoff: float, default=9000 - vdw cutoff of D3 calculator in au + vdw cutoff of D3 calculator in au^2 cn_cutoff: float, default=1600 - cn cutoff of D3 calculator in au + cn cutoff of D3 calculator in au^2 """ d3_calc = D3Calculator( damping_type=damping_type, @@ -298,6 +308,10 @@ def __init__( model=model, file_type=file_type, device=device, + modal=modal, + enable_cueq=enable_cueq, + enable_flash=enable_flash, + enable_oeq=enable_oeq, sevennet_config=sevennet_config, **kwargs, ) diff --git a/sevenn/checkpoint.py b/sevenn/checkpoint.py index a87e1212..e0422ec2 100644 --- a/sevenn/checkpoint.py +++ b/sevenn/checkpoint.py @@ -332,6 +332,9 @@ def build_model( cp_using_oeq = self.config.get(KEY.USE_OEQ, False) enable_oeq = cp_using_oeq if enable_oeq is None else enable_oeq + if sum([enable_cueq, enable_flash, enable_oeq]) > 1: + raise ValueError('Only one TP accelerator can be enabled.') + assert not _flash_lammps or enable_flash cfg_new = self.config cfg_new['_flash_lammps'] = _flash_lammps From 80a163820e5e117091620d14586a33b58409297a Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 10 Mar 2026 14:02:20 +0900 Subject: [PATCH 08/10] fix typo --- tests/unit_tests/test_torchsim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_torchsim.py 
b/tests/unit_tests/test_torchsim.py index 6f92a6b1..3fb75495 100644 --- a/tests/unit_tests/test_torchsim.py +++ b/tests/unit_tests/test_torchsim.py @@ -77,7 +77,7 @@ def test_sevenn_model_dtype_validation( """Test that SevenNetModel raises an error if dtype is not float32.""" with pytest.raises( ValueError, - match='SevenNetModel currently only supports torch.float32, but received ' + match='SevenNet currently only supports torch.float32, but received ' + 'different dtype: torch.float64', ): _ = SevenNetModel( From d16fd79d6f6b7936a9a0fea088dcde85617090be Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Tue, 10 Mar 2026 14:48:20 +0900 Subject: [PATCH 09/10] add oeq option in torchsim --- sevenn/torchsim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sevenn/torchsim.py b/sevenn/torchsim.py index 8425e53f..d7b7211c 100644 --- a/sevenn/torchsim.py +++ b/sevenn/torchsim.py @@ -74,8 +74,9 @@ def __init__( *, # force remaining arguments to be keyword-only modal: str | None = None, neighbor_list_fn: Callable | None = None, - enable_flash: bool = False, enable_cueq: bool = False, + enable_flash: bool = False, + enable_oeq: bool = False, device: torch.device | str = 'auto', dtype: torch.dtype = torch.float32, ) -> None: @@ -94,6 +95,7 @@ def __init__( 'omat24' (OMat24). enable_cueq (bool): Enable cuEquivariance backend. enable_flash (bool): Enable flashTP backend. + enable_oeq (bool): Enable OpenEquivariance backend. neighbor_list_fn (Callable): Neighbor list function to use. Default is torch_nl_linked_cell. 
device (torch.device | str): Device to run the model on @@ -133,6 +135,7 @@ def __init__( model = cp.build_model( enable_flash=enable_flash, enable_cueq=enable_cueq, + enable_oeq=enable_oeq, ) _validate(model, modal) From 2586ef3e006f7f40e816f6aa9001ca2bca6a85cb Mon Sep 17 00:00:00 2001 From: alphalm4 Date: Sun, 15 Mar 2026 18:22:17 +0900 Subject: [PATCH 10/10] add changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94173f2c..c55aa959 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. ### Added - Support OpenEquivariance +### Changed +- **[Breaking]** Rename optional dependency group `mliap` into `mliap12` (reflecting its CUDA 12.x dependency). +- Add `cueq13` and `mliap13` optional dependency groups for CUDA 13.x. + ## [0.12.1] ### Fixed - FlashTP with LAMMPS parallel in torch