Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ ARCHITECTURAL_HYPERS:
K_CUT: None # should be float; only used when USE_LONG_RANGE is True
K_CUT_DELTA: None
DTYPE: float32 # float32 or float16 or bfloat16
N_TARGETS: 1
TARGET_INDEX_KEY: target_index
RESIDUAL_FACTOR: 0.5

FITTING_SCHEME:

FITTING_SCHEME:
INITIAL_LR: 1e-4
EPOCH_NUM_ATOMIC: 1000000000000000000 # structural version is called "EPOCH_NUM"
SCHEDULER_STEP_SIZE_ATOMIC: 500000000 # structural version is called "SCHEDULER_STEP_SIZE"
Expand All @@ -44,6 +47,7 @@ FITTING_SCHEME:
RANDOM_SEED: 0
CUDA_DETERMINISTIC: False
MODEL_TO_START_WITH: None
ALL_SPECIES_PATH: None # should be specified if model_to_start_with is not None; path to all_species file
SUPPORT_MISSING_VALUES: False
USE_WEIGHT_DECAY: False
WEIGHT_DECAY: 0.0
Expand All @@ -58,7 +62,7 @@ MLIP_SETTINGS: # only used when fitting MLIP
USE_ENERGIES: True
USE_FORCES: True

GENERAL_TARGET_SETTINGS: # only used when fitting general target
GENERAL_TARGET_SETTINGS: # only used when fitting general target
TARGET_TYPE: structural
TARGET_AGGREGATION: sum # sum or mean; only used when TARGET_TYPE is structural
TARGET_DIM: 42
Expand Down
40 changes: 40 additions & 0 deletions example/h2_train.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-8.759742673137072 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 1.53208638 0.13098867 4.53832253
H 0.76553186 0.00000000 0.00000000 -1.16199225 -1.58021080 0.33524214
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-3.7192079015352757 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 2.73044949 0.06874556 0.21296946
H 1.00466482 0.00000000 0.00000000 3.11014793 0.53780466 -3.79237256
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=7.189273706327455 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 1.53916231 4.99280004 4.67277815
H 0.90433764 0.00000000 0.00000000 4.82751479 -3.72742755 1.53408770
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=8.17870662429138 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 3.66308968 3.87160588 4.82455171
H 0.95810929 0.00000000 0.00000000 -2.04872371 0.41517506 -2.14607990
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-5.993965811098398 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 3.96966993 2.27459901 2.29572060
H 0.89272778 0.00000000 0.00000000 4.13658605 2.63264531 0.03607745
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=5.5484447205560645 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.03319915 3.52230195 3.02470110
H 0.87884678 0.00000000 0.00000000 3.71634012 4.60398225 4.93817840
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-5.814588656353306 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 4.38672742 -3.38098568 3.96265227
H 1.10390324 0.00000000 0.00000000 -2.36787831 1.26632275 -2.54980052
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=4.822978863120337 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.16551747 4.41462459 -2.12527669
H 0.92116485 0.00000000 0.00000000 1.53299624 -4.04118132 -3.07972468
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=2.153278030275324 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.50090037 2.51211053 4.84955928
H 1.15947263 0.00000000 0.00000000 -1.38685321 4.87367691 1.93581327
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-9.691070544370257 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -2.43225987 1.17215292 0.49919348
H 1.05785635 0.00000000 0.00000000 0.84314954 -0.43226929 4.92230042
40 changes: 40 additions & 0 deletions example/h2_val.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-8.759742673137072 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 1.53208638 0.13098867 4.53832253
H 0.76553186 0.00000000 0.00000000 -1.16199225 -1.58021080 0.33524214
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-3.7192079015352757 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 2.73044949 0.06874556 0.21296946
H 1.00466482 0.00000000 0.00000000 3.11014793 0.53780466 -3.79237256
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=7.189273706327455 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 1.53916231 4.99280004 4.67277815
H 0.90433764 0.00000000 0.00000000 4.82751479 -3.72742755 1.53408770
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=8.17870662429138 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 3.66308968 3.87160588 4.82455171
H 0.95810929 0.00000000 0.00000000 -2.04872371 0.41517506 -2.14607990
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-5.993965811098398 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 3.96966993 2.27459901 2.29572060
H 0.89272778 0.00000000 0.00000000 4.13658605 2.63264531 0.03607745
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=5.5484447205560645 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.03319915 3.52230195 3.02470110
H 0.87884678 0.00000000 0.00000000 3.71634012 4.60398225 4.93817840
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-5.814588656353306 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 4.38672742 -3.38098568 3.96265227
H 1.10390324 0.00000000 0.00000000 -2.36787831 1.26632275 -2.54980052
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=4.822978863120337 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.16551747 4.41462459 -2.12527669
H 0.92116485 0.00000000 0.00000000 1.53299624 -4.04118132 -3.07972468
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=2.153278030275324 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -1.50090037 2.51211053 4.84955928
H 1.15947263 0.00000000 0.00000000 -1.38685321 4.87367691 1.93581327
2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-9.691070544370257 pbc="F F F"
H 0.00000000 0.00000000 0.00000000 -2.43225987 1.17215292 0.49919348
H 1.05785635 0.00000000 0.00000000 0.84314954 -0.43226929 4.92230042
12 changes: 12 additions & 0 deletions example/hypers_finetuning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
ARCHITECTURAL_HYPERS:
R_CUT: 100
N_TRANS_LAYERS: 2
N_GNN_LAYERS: 2

FITTING_SCHEME:
EPOCH_NUM: 10
EPOCHS_WARMUP: 0
MODEL_TO_START_WITH: results/pretraining/best_val_mae_both_model_state_dict
ALL_SPECIES_PATH: results/pretraining/all_species.npy
MLIP_SETTINGS:
USE_ENERGIES: True
4 changes: 3 additions & 1 deletion src/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def get_pyg_graphs(
USE_ADDITIONAL_SCALAR_ATTRIBUTES,
USE_LONG_RANGE,
K_CUT,
MULTI_TARGET,
TARGET_INDEX_KEY
):
molecules = [
Molecule(
structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES, USE_LONG_RANGE, K_CUT
structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES, USE_LONG_RANGE, K_CUT, MULTI_TARGET, TARGET_INDEX_KEY
)
for structure in tqdm(structures)
]
Expand Down
2 changes: 2 additions & 0 deletions src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def main():
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT,
ARCHITECTURAL_HYPERS.N_TARGETS > 1,
ARCHITECTURAL_HYPERS.TARGET_INDEX_KEY
)

if FITTING_SCHEME.MULTI_GPU:
Expand Down
7 changes: 7 additions & 0 deletions src/hypers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def combine_hypers(provided_hypers, default_hypers):
raise ValueError(
"if using balanced_data_loader only atomic batch size can be provided"
)

if result["FITTING_SCHEME"]["MODEL_TO_START_WITH"] is not None:
if result["FITTING_SCHEME"]["ALL_SPECIES_PATH"] is None:
raise ValueError(
"if using model_to_start_with, all_species_path must be provided"
)

return result


Expand Down
26 changes: 22 additions & 4 deletions src/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class Molecule:
def __init__(
self, atoms, r_cut, use_additional_scalar_attributes, use_long_range, k_cut
self, atoms, r_cut, use_additional_scalar_attributes, use_long_range, k_cut, multi_target, target_index_key
):

self.use_additional_scalar_attributes = use_additional_scalar_attributes
Expand Down Expand Up @@ -75,6 +75,11 @@ def is_same(first, second):
self.k_cut = k_cut
self.volume = get_volume(self.cell[0], self.cell[1], self.cell[2])

if multi_target:
self.target_index = int(self.atoms.info[target_index_key])
else:
self.target_index = None

def get_max_num(self):
maximum = None
for chunk in self.relative_positions:
Expand Down Expand Up @@ -151,6 +156,9 @@ def get_graph(self, max_num, all_species, max_num_k):
"n_atoms": len(self.atoms.positions),
}

if self.target_index is not None:
kwargs['target_id'] = self.target_index

if self.use_additional_scalar_attributes:
kwargs["neighbor_scalar_attributes"] = torch.tensor(
neighbor_scalar_attributes, dtype=torch.get_default_dtype()
Expand Down Expand Up @@ -190,7 +198,7 @@ def get_graph(self, max_num, all_species, max_num_k):

class MoleculeCPP:
def __init__(
self, atoms, r_cut, use_additional_scalar_attributes, use_long_range, k_cut
self, atoms, r_cut, use_additional_scalar_attributes, use_long_range, k_cut, multi_target, target_index_key
):

self.use_additional_scalar_attributes = use_additional_scalar_attributes
Expand All @@ -203,7 +211,7 @@ def __init__(
raise NotImplementedError("Long range is not implemented in cpp")
if self.use_additional_scalar_attributes:
raise NotImplementedError("Additional scalar attributes are not implemented in cpp")

def is_3d_crystal(atoms):
pbc = atoms.get_pbc()
if isinstance(pbc, bool):
Expand All @@ -227,9 +235,14 @@ def is_3d_crystal(atoms):
else:
self.max_num = torch.max(torch.bincount(self.i_list))

if multi_target:
self.target_index = int(self.atoms.info[target_index_key])
else:
self.target_index = None

def get_num_k(self):
raise NotImplementedError("Long range is not implemented in cpp")

def get_max_num(self):
return self.max_num

Expand All @@ -251,6 +264,9 @@ def get_graph(self, max_num, all_species, max_num_k):
"n_atoms": len(self.atoms.positions),
}

if self.target_index is not None:
kwargs['target_id'] = self.target_index

result = Data(**kwargs)

return result
Expand All @@ -266,6 +282,8 @@ def batch_to_dict(batch):
"neighbors_index": batch.neighbors_index.transpose(0, 1),
"neighbors_pos": batch.neighbors_pos,
}
if hasattr(batch, 'target_id'):
batch_dict['target_id'] = batch.target_id

if hasattr(batch, "neighbor_scalar_attributes"):
batch_dict["neighbor_scalar_attributes"] = batch.neighbor_scalar_attributes
Expand Down
Loading
Loading