
Bug when fine-tuning the OMat24 model #248

@jinlhr542


I am trying to fine-tune the `checkpoint_sevennet_mf_ompa.pth` model with the code below, but `Trainer.from_config` fails with an `AttributeError`:

```python
import sevenn.util as util
from sevenn.train.graph_dataset import SevenNetGraphDataset

model, config = util.model_from_checkpoint('checkpoint_sevennet_mf_ompa.pth')
cutoff = config['cutoff']
# working_dir and dataset_files are defined elsewhere (not shown)
dataset = SevenNetGraphDataset(cutoff=cutoff, root=working_dir, files=dataset_files, processed_name='train.pt')
```

```python
from sevenn.train.trainer import Trainer
import torch.optim.lr_scheduler as scheduler

trainer = Trainer.from_config(model, config)
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 4
      1 from sevenn.train.trainer import Trainer
      2 import torch.optim.lr_scheduler as scheduler
----> 4 trainer = Trainer.from_config(model, config)
      6 # We have energy, force, stress loss function, which used to train 7net-0.
      7 # We will use it as it is, with loss weight: 1.0, 1.0, and 0.01 for energy, force, and stress, respectively.
      8 print(trainer.loss_functions)

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/trainer.py:88, in Trainer.from_config(model, config)
     84 @staticmethod
     85 def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
     86     trainer = Trainer(
     87         model,
---> 88         loss_functions=get_loss_functions_from_config(config),
     89         optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
     90         optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
     91         scheduler_cls=scheduler_dict[
     92             config.get(KEY.SCHEDULER, 'exponentiallr').lower()
     93         ],
     94         scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
     95         device=config.get(KEY.DEVICE, 'auto'),
     96         distributed=config.get(KEY.IS_DDP, False),
     97         distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
     98     )
     99     return trainer

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/loss.py:211, in get_loss_functions_from_config(config)
    207 from sevenn.train.optim import loss_dict
    209 loss_functions = []  # list of tuples (loss_definition, weight)
--> 211 loss = loss_dict[config[KEY.LOSS].lower()]
    212 loss_param = config.get(KEY.LOSS_PARAM, {})
    214 use_weight = config.get(KEY.USE_WEIGHT, False)

AttributeError: 'dict' object has no attribute 'lower'
```
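The failing line in `get_loss_functions_from_config` calls `.lower()` on `config[KEY.LOSS]`, so it assumes that entry is a string, whereas the config loaded from `checkpoint_sevennet_mf_ompa.pth` holds a dict there. `Trainer.from_config` also calls `.lower()` on the optimizer and scheduler entries, so those could fail the same way. A minimal, pure-Python sketch for listing such entries before calling `Trainer.from_config` follows; the helper `find_non_string_entries` is hypothetical and not part of SevenNet, and it only diagnoses the problem rather than fixing it.

```python
from typing import Any, Dict, Iterable, List, Tuple


def find_non_string_entries(
    config: Dict[str, Any], keys: Iterable[str]
) -> List[Tuple[str, str]]:
    """Return (key, type name) for each present key whose value is not a str.

    Trainer.from_config calls .lower() on these values, so any entry
    reported here would raise an AttributeError like the one above.
    Keys missing from the config are skipped; only present values are checked.
    """
    problems = []
    for key in keys:
        if key in config and not isinstance(config[key], str):
            problems.append((key, type(config[key]).__name__))
    return problems


# Toy config (hypothetical values) mimicking the failure: loss stored as a dict.
toy_config = {"loss": {"energy": "mse"}, "optimizer": "adam", "cutoff": 5.0}
print(find_non_string_entries(toy_config, ["loss", "optimizer", "scheduler"]))
```

In the notebook above this could be called as `find_non_string_entries(config, [KEY.LOSS, KEY.OPTIMIZER, KEY.SCHEDULER])`, assuming `KEY` refers to `sevenn._keys` as it does inside SevenNet's own modules; printing `config[KEY.LOSS]` then shows what the multi-fidelity checkpoint actually stores in place of a loss name.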
