From c6e5aa54088b9dd632c4faab9106f7954afb307d Mon Sep 17 00:00:00 2001 From: sgoodlett Date: Mon, 18 Aug 2025 16:03:55 +0200 Subject: [PATCH 1/2] Integraging multifidelity models --- peslearn/ml/__init__.py | 2 + peslearn/ml/mfnn/__init__.py | 9 + peslearn/ml/mfnn/delta.py | 10 + peslearn/ml/mfnn/dual.py | 142 ++++++++++ peslearn/ml/mfnn/mknn.py | 218 +++++++++++++++ peslearn/ml/mfnn/weight_transfer.py | 416 ++++++++++++++++++++++++++++ 6 files changed, 797 insertions(+) create mode 100644 peslearn/ml/mfnn/__init__.py create mode 100644 peslearn/ml/mfnn/delta.py create mode 100644 peslearn/ml/mfnn/dual.py create mode 100644 peslearn/ml/mfnn/mknn.py create mode 100644 peslearn/ml/mfnn/weight_transfer.py diff --git a/peslearn/ml/__init__.py b/peslearn/ml/__init__.py index 4c33d1c..48ac608 100644 --- a/peslearn/ml/__init__.py +++ b/peslearn/ml/__init__.py @@ -4,6 +4,8 @@ from . import preprocessing_helper from . import model +from . import mfnn + from .gaussian_process import GaussianProcess from .data_sampler import DataSampler from .neural_network import NeuralNetwork diff --git a/peslearn/ml/mfnn/__init__.py b/peslearn/ml/mfnn/__init__.py new file mode 100644 index 0000000..d8dfcf5 --- /dev/null +++ b/peslearn/ml/mfnn/__init__.py @@ -0,0 +1,9 @@ +from . import dual +from . import delta +from . import weight_transfer +from . import mknn + +from .dual import DualNN +from .delta import DeltaNN +from .weight_transfer import WTNN +from .mknn import MKNN diff --git a/peslearn/ml/mfnn/delta.py b/peslearn/ml/mfnn/delta.py new file mode 100644 index 0000000..1365b69 --- /dev/null +++ b/peslearn/ml/mfnn/delta.py @@ -0,0 +1,10 @@ +from copy import deepcopy +from ..neural_network import NeuralNetwork + +class DeltaNN(NeuralNetwork): + def __init__(self, dataset_path, input_obj, molecule_type=None, molecule=None, train_path=None, test_path=None, valid_path=None): + super().__init__(dataset_path, input_obj, molecule_type, molecule, train_path, test_path, valid_path) + lf_E = self.raw_X[:,-1].reshape(-1,1) + self.raw_X = self.raw_X[:,:-1] + self.raw_y = deepcopy(self.raw_y) - lf_E # If modified in place (i.e. 
self.raw_y -= lf_E) then PES.dat will be modified to delta rather than HF_E + \ No newline at end of file diff --git a/peslearn/ml/mfnn/dual.py b/peslearn/ml/mfnn/dual.py new file mode 100644 index 0000000..b604d25 --- /dev/null +++ b/peslearn/ml/mfnn/dual.py @@ -0,0 +1,142 @@ +import torch +import numpy as np +from ..neural_network import NeuralNetwork +import os +from copy import deepcopy +from ...constants import package_directory +from ..preprocessing_helper import morse, interatomics_to_fundinvar, degree_reduce, general_scaler +from sklearn.model_selection import train_test_split + +torch.set_printoptions(precision=15) + +class DualNN(NeuralNetwork): + def __init__(self, dataset_path, input_obj, molecule_type=None, molecule=None, train_path=None, test_path=None, valid_path=None): + #super().__init__(dataset_path, input_obj, molecule_type, molecule, train_path, test_path, valid_path) + super().__init__(dataset_path, input_obj, molecule_type, molecule, train_path, test_path, valid_path) + self.trial_layers = self.input_obj.keywords['nas_trial_layers'] + self.set_default_hyperparameters() + + if self.input_obj.keywords['validation_points']: + self.nvalid = self.input_obj.keywords['validation_points'] + if (self.nvalid + self.ntrain + 1) > self.n_datapoints: + raise Exception("Error: User-specified training set size and validation set size exceeds the size of the dataset.") + else: + self.nvalid = round((self.n_datapoints - self.ntrain) / 2) + + if self.pip: + if molecule_type: + path = os.path.join(package_directory, "lib", molecule_type, "output") + self.inp_dim = len(open(path).readlines())+1 + if molecule: + path = os.path.join(package_directory, "lib", molecule.molecule_type, "output") + self.inp_dim = len(open(path).readlines())+1 + else: + self.inp_dim = self.raw_X.shape[1] + + def split_train_test(self, params, validation_size=None, precision=32): + self.X, self.y, self.Xscaler, self.yscaler, self.lf_E_scaler = self.preprocess(params, self.raw_X, self.raw_y) + if self.sampler == 'user_supplied': + self.Xtr = self.transform_new_X(self.raw_Xtr, params, self.Xscaler) + self.ytr = self.transform_new_y(self.raw_ytr, self.yscaler) + self.Xtest = self.transform_new_X(self.raw_Xtest, params, self.Xscaler) + self.ytest = self.transform_new_y(self.raw_ytest, self.yscaler) + if self.valid_path: + self.Xvalid = self.transform_new_X(self.raw_Xvalid, params, self.Xscaler) + self.yvalid = self.transform_new_y(self.raw_yvalid, self.yscaler) + else: + raise Exception("Please provide a validation set for Neural Network training.") + else: + self.Xtr = self.X[self.train_indices] + self.ytr = self.y[self.train_indices] + #TODO: this is splitting validation data in the same way at every model build, not necessary. 
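+            # The remaining (non-training) indices are split once more below: `validation_size` of them
+            # become the validation set and the rest the final test set. random_state=42 makes this
+            # partition deterministic, so every build_model call reuses the same split (see TODO above).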
+ self.valid_indices, self.new_test_indices = train_test_split(self.test_indices, train_size = validation_size, random_state=42) + if validation_size: + self.Xvalid = self.X[self.valid_indices] + self.yvalid = self.y[self.valid_indices] + self.Xtest = self.X[self.new_test_indices] + self.ytest = self.y[self.new_test_indices] + + else: + raise Exception("Please specify a validation set size for Neural Network training.") + + # convert to Torch Tensors + if precision == 32: + self.Xtr = torch.tensor(self.Xtr, dtype=torch.float32) + self.ytr = torch.tensor(self.ytr, dtype=torch.float32) + self.Xtest = torch.tensor(self.Xtest, dtype=torch.float32) + self.ytest = torch.tensor(self.ytest, dtype=torch.float32) + self.Xvalid = torch.tensor(self.Xvalid,dtype=torch.float32) + self.yvalid = torch.tensor(self.yvalid,dtype=torch.float32) + self.X = torch.tensor(self.X,dtype=torch.float32) + self.y = torch.tensor(self.y,dtype=torch.float32) + elif precision == 64: + self.Xtr = torch.tensor(self.Xtr, dtype=torch.float64) + self.ytr = torch.tensor(self.ytr, dtype=torch.float64) + self.Xtest = torch.tensor(self.Xtest, dtype=torch.float64) + self.ytest = torch.tensor(self.ytest, dtype=torch.float64) + self.Xvalid = torch.tensor(self.Xvalid,dtype=torch.float64) + self.yvalid = torch.tensor(self.yvalid,dtype=torch.float64) + self.X = torch.tensor(self.X,dtype=torch.float64) + self.y = torch.tensor(self.y,dtype=torch.float64) + else: + raise Exception("Invalid option for 'precision'") + + def preprocess(self, params, raw_X_less, raw_y): + """ + Preprocess raw data according to hyperparameters + """ + lf_E = deepcopy(raw_X_less[:,-1].reshape(-1,1)) + raw_X = deepcopy(raw_X_less[:,:-1]) + if params['morse_transform']['morse']: + raw_X = morse(raw_X, params['morse_transform']['morse_alpha']) + if params['pip']['pip']: + # find path to fundamental invariants form molecule type AxByCz... + path = os.path.join(package_directory, "lib", self.molecule_type, "output") + #lf_E = raw_X[:,-1] + raw_X, degrees = interatomics_to_fundinvar(raw_X,path) + #raw_X = np.hstack((raw_X, lf_E[:,None])) + if params['pip']['degree_reduction']: + #raw_X[:,:-1] = degree_reduce(raw_X[:,:-1], degrees) + raw_X = degree_reduce(raw_X, degrees) + if params['scale_X']: + X, Xscaler = general_scaler(params['scale_X']['scale_X'], raw_X) + else: + X = raw_X + Xscaler = None + if params['scale_y']: + lf_E, lf_E_scaler = general_scaler(params['scale_y'], lf_E) + y, yscaler = general_scaler(params['scale_y'], raw_y) + else: + lf_E_scaler = None + y = raw_y + yscaler = None + X = np.hstack((X, lf_E)) + #X = np.hstack((X, lf_E[:,None])) + return X, y, Xscaler, yscaler, lf_E_scaler + + def transform_new_X(self, newX, params, Xscaler=None, lf_E_scaler=None): + """ + Transform a new, raw input according to the model's transformation procedure + so that prediction can be made. + """ + # ensure X dimension is n x m (n new points, m input variables) + if len(newX.shape) == 1: + newX = np.expand_dims(newX,0) + elif len(newX.shape) > 2: + raise Exception("Dimensions of input data is incorrect.") + newX_geom = newX[:,:-1] + lf_E = newX[:,-1].reshape(-1,1) + if params['morse_transform']['morse']: + newX_geom = morse(newX_geom, params['morse_transform']['morse_alpha']) + if params['pip']['pip']: + # find path to fundamental invariants for an N atom system with molecule type AxByCz... 
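+        # Only the geometry columns (newX_geom) pass through the Morse/PIP/scaling pipeline below,
+        # mirroring preprocess(); the low-fidelity energy split off above is scaled with lf_E_scaler
+        # (when given) and re-appended as the final input feature before prediction.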
+ path = os.path.join(package_directory, "lib", self.molecule_type, "output") + newX_geom, degrees = interatomics_to_fundinvar(newX_geom,path) + if params['pip']['degree_reduction']: + newX_geom = degree_reduce(newX_geom, degrees) + if Xscaler: + newX_geom = Xscaler.transform(newX_geom) + if lf_E_scaler: + lf_E = lf_E_scaler.transform(lf_E) + #lf_E = lf_E.reshape(-1,1) + return np.hstack((newX_geom, lf_E)) diff --git a/peslearn/ml/mfnn/mknn.py b/peslearn/ml/mfnn/mknn.py new file mode 100644 index 0000000..98e810b --- /dev/null +++ b/peslearn/ml/mfnn/mknn.py @@ -0,0 +1,218 @@ +import numpy as np +from .weight_transfer import WTNN +import torch +import torch.nn as nn +from collections import OrderedDict +from ...constants import hartree2cm +import copy + +class MKNNModel(nn.Module): + def __init__(self, inp_dim, layers, activ) -> None: + super(MKNNModel, self).__init__() + + depth = len(layers) + structure_lf = OrderedDict([('input', nn.Linear(inp_dim, layers[0])), + ('activ_in' , activ)]) + self.model_lf = nn.Sequential(structure_lf) + for i in range(depth-1): + self.model_lf.add_module('layer' + str(i), nn.Linear(layers[i], layers[i+1])) + self.model_lf.add_module('activ' + str(i), activ) + self.model_lf.add_module('output', nn.Linear(layers[depth-1], 1)) + + #structure_hf = OrderedDict([('input', nn.Linear(inp_dim+1, layers[0])), + # ('activ_in' , activ)]) # Add one to inp_dim for LF energy + #self.nonlinear_hf = nn.Sequential(structure_hf) # Nonlinear NN for HF prediction + #for i in range(depth-1): + # self.nonlinear_hf.add_module('layer' + str(i), nn.Linear(layers[i], layers[i+1])) + # self.nonlinear_hf.add_module('activ' + str(i), activ) + #self.nonlinear_hf.add_module('output', nn.Linear(layers[depth-1], 1)) + self.nonlinear_hf = nn.Sequential( + nn.Linear(inp_dim+1,32), + nn.Tanh(), + nn.Linear(32,32), + nn.Tanh(), + nn.Linear(32,32), + nn.Tanh(), + nn.Linear(32,1), + nn.Tanh()) + + self.linear_hf = nn.Linear(inp_dim+1,1) # Linear NN + + def forward(self, xh, xl): + yl = self.model_lf(xl) + yl_xh = self.model_lf(xh) + #print(xh.shape) + #print(yl_xh.shape) + hin = torch.cat((xh,yl_xh), dim=1) + nliny = self.nonlinear_hf(hin) + liny = self.linear_hf(hin) + yh = liny + nliny + return yh, yl + + +class MKNN(WTNN): + def __init__(self, dataset_path, dataset_path_lf, input_obj, input_obj_lf, molecule_type=None, molecule=None, train_path=None, test_path=None, valid_path=None): + super().__init__(dataset_path, dataset_path_lf, input_obj, input_obj_lf, molecule_type, molecule, train_path, test_path, valid_path) + + def build_model(self, params, maxit=1000, val_freq=10, es_patience=2, opt='lbfgs', tol=1, decay=False, verbose=False, precision=32, return_model=False): + print("Hyperparameters: ", params) + self.split_train_test(params, validation_size=self.nvalid, validation_size_lf=self.nvalid_lf, precision=precision) # split data, according to scaling hp's + scale = params['scale_y'] # Find descaling factor to convert loss to original energy units + if scale == 'std': + loss_descaler = self.yscaler.var_[0] + if scale.startswith('mm'): + loss_descaler = (1/self.yscaler.scale_[0]**2) + + activation = params['scale_X']['activation'] + if activation == 'tanh': + activ = nn.Tanh() + if activation == 'sigmoid': + activ = nn.Sigmoid() + + inp_dim = self.inp_dim + l = params['layers'] + torch.manual_seed(0) + + model = MKNNModel(inp_dim, l, activ) + + if precision == 64: # cast model to proper precision + model = model.double() + + metric = torch.nn.MSELoss() + # Define optimizer + if 'lr' in 
params: + lr = params['lr'] + elif opt == 'lbfgs': + lr = 0.5 + else: + lr = 0.1 + + optimizer = self.get_optimizer(opt, model.parameters(), lr=lr) + #optimizer = torch.optim.Adam(model.parameters(), lr=lr*0.01) + # Define update variables for early stopping, decay, gradient explosion handling + prev_loss = 1.0 + es_tracker = 0 + best_val_error = None + failures = 0 + decay_attempts = 0 + prev_best = None + decay_start = False + maxit += 5000 + labda = 1e-6 #l2_norm = sum(p.pow(2.0).sum() for p in model.parameters()) + for epoch in range(1,maxit): + def closure(): + optimizer.zero_grad() + y_pred_hf, y_pred_lf = model(self.Xtr, self.Xtr_lf) + loss = torch.sqrt(metric(y_pred_lf, self.ytr_lf)) + torch.sqrt(metric(y_pred_hf, self.ytr)) + labda*sum(p.pow(2.0).sum() for p in model.parameters()) # L2 regularization + loss.backward() + return loss + optimizer.step(closure) + # validate + if epoch % val_freq == 0: + with torch.no_grad(): + tmp_pred, trash = model(self.Xvalid, self.Xvalid) + tmp_loss = metric(tmp_pred, self.yvalid) + val_error_rmse = np.sqrt(tmp_loss.item() * loss_descaler) * hartree2cm # loss_descaler converts MSE in scaled data domain to MSE in unscaled data domain + if best_val_error: + if val_error_rmse < best_val_error: + prev_best = best_val_error * 1.0 + best_val_error = val_error_rmse * 1.0 + else: + record = True + best_val_error = val_error_rmse * 1.0 + prev_best = best_val_error + if verbose: + print("Epoch {} Validation RMSE (cm-1): {:5.3f}".format(epoch, val_error_rmse)) + if decay_start: + scheduler.step(val_error_rmse) + + # Early Stopping + if epoch > 5: + # if current validation error is not the best (current - best > 0) and is within tol of previous error, the model is stagnant. + if ((val_error_rmse - prev_loss) < tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: current validation error is not the best (current - best > 0) and is greater than the best by tol, the model is overfitting. Bad epoch. + elif ((val_error_rmse - best_val_error) > tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: if the current validation error is a new record, but not significant, the model is stagnant + elif (prev_best - best_val_error) < 0.001: + es_tracker += 1 + # else: model set a new record validation error. Reset early stopping tracker + else: + es_tracker = 0 + #TODO this framework does not detect oscillatory behavior about 'tol', though this has not been observed to occur in any case + # Check status of early stopping tracker. First try decaying to see if stagnation can be resolved, if not then terminate training + if es_tracker > es_patience: + if decay: # if decay is set to true, if early stopping criteria is triggered, begin LR scheduler and go back to previous model state and attempt LR decay. + if decay_attempts < 1: + decay_attempts += 1 + es_tracker = 0 + if verbose: + print("Performance plateau detected. 
Reverting model state and decaying learning rate.") + decay_start = True + thresh = (0.1 / np.sqrt(loss_descaler)) / hartree2cm # threshold is 0.1 wavenumbers + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, threshold=thresh, threshold_mode='abs', min_lr=0.05, cooldown=2, patience=10, verbose=verbose) + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.9 + optimizer.load_state_dict(saved_optimizer_state_dict) + # Since learning rate is decayed, override tolerance, patience, validation frequency for high-precision + #tol = 0.05 + #es_patience = 100 + #val_freq = 1 + continue + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + + # Handle exploding gradients + if epoch > 10: + if (val_error_rmse > prev_loss*10): # detect large increases in loss + if epoch > 60: # distinguish between exploding gradients at near converged models and early on exploding grads + if verbose: + print("Exploding gradient detected. Resuming previous model state and decaying learning rate") + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.5 + optimizer.load_state_dict(saved_optimizer_state_dict) + failures += 1 # if + if failures > 2: + break + else: + continue + else: + break + if val_error_rmse != val_error_rmse: # detect NaN + break + if ((prev_loss < 1.0) and (precision == 32)): # if 32 bit precision and model is giving very high accuracy, kill so the accuracy does not go beyond 32 bit precision + break + prev_loss = val_error_rmse * 1.0 # save previous loss to track improvement + + # Periodically save model state so we can reset under instability/overfitting/performance plateau + if epoch % 50 == 0: + saved_model_state_dict = copy.deepcopy(model.state_dict()) + saved_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + with torch.no_grad(): + train_pred, trash = model(self.Xtr, self.Xtr) + train_loss = metric(train_pred, self.ytr) + train_error_rmse = np.sqrt(train_loss.item() * loss_descaler) * hartree2cm + test_pred, trash = model(self.Xtest, self.Xtest) + test_loss = metric(test_pred, self.ytest) + test_error_rmse = np.sqrt(test_loss.item() * loss_descaler) * hartree2cm + val_pred, trash = model(self.Xvalid, self.Xvalid) + val_loss = metric(val_pred, self.yvalid) + val_error_rmse = np.sqrt(val_loss.item() * loss_descaler) * hartree2cm + full_pred, trash = model(self.X, self.X) + full_loss = metric(full_pred, self.y) + full_error_rmse = np.sqrt(full_loss.item() * loss_descaler) * hartree2cm + print("Test set RMSE (cm-1): {:5.2f} Validation set RMSE (cm-1): {:5.2f} Train set RMSE: {:5.2f} Full dataset RMSE (cm-1): {:5.2f}".format(test_error_rmse, val_error_rmse, train_error_rmse, full_error_rmse)) + if return_model: + return model, test_error_rmse, val_error_rmse, full_error_rmse + else: + return test_error_rmse, val_error_rmse diff --git a/peslearn/ml/mfnn/weight_transfer.py b/peslearn/ml/mfnn/weight_transfer.py new file mode 100644 index 0000000..306596a --- /dev/null +++ b/peslearn/ml/mfnn/weight_transfer.py @@ -0,0 +1,416 @@ +from ..neural_network import NeuralNetwork +from ..model import Model +import torch +import torch.nn as nn +from collections import OrderedDict +from sklearn.model_selection import train_test_split +from ...constants import hartree2cm +import copy +import numpy as np + +class 
WTNN(NeuralNetwork): + def __init__(self, dataset_path, dataset_path_lf, input_obj, input_obj_lf, molecule_type=None, molecule=None, train_path=None, test_path=None, valid_path=None): + super().__init__(dataset_path, input_obj, molecule_type, molecule, train_path, test_path, valid_path) + self.lf_model = NeuralNetwork(dataset_path_lf, input_obj_lf, molecule_type, molecule) + #self.lf_model = Model(dataset_path_lf, input_obj_lf, molecule_type, molecule, train_path, test_path, valid_path) # TODO: Train, test, valid paths are for HF model + if self.lf_model.input_obj.keywords['validation_points']: + self.nvalid_lf = self.lf_model.input_obj.keywords['validation_points'] + if (self.nvalid_lf + self.lf_model.ntrain + 1) > self.lf_model.n_datapoints: + raise Exception("Error: User-specified training set size and validation set size exceeds the size of the dataset.") + else: + self.nvalid_lf = round((self.lf_model.n_datapoints - self.lf_model.ntrain) / 2) + + def split_train_test(self, params, validation_size=None, validation_size_lf=None, precision=32): + self.X, self.y, self.Xscaler, self.yscaler = self.preprocess(params, self.raw_X, self.raw_y) + self.X_lf, self.y_lf, self.Xscaler_lf, self.yscaler_lf = self.preprocess(params, self.lf_model.raw_X, self.lf_model.raw_y) + if self.sampler == 'user_supplied': + self.Xtr = self.transform_new_X(self.raw_Xtr, params, self.Xscaler) + self.ytr = self.transform_new_y(self.raw_ytr, self.yscaler) + self.Xtest = self.transform_new_X(self.raw_Xtest, params, self.Xscaler) + self.ytest = self.transform_new_y(self.raw_ytest, self.yscaler) + + self.Xtr_lf = self.transform_new_X(self.lf_model.raw_Xtr, params, self.Xscaler_lf) + self.ytr_lf = self.transform_new_y(self.lf_model.raw_ytr, self.yscaler_lf) + self.Xtest_lf = self.transform_new_X(self.lf_model.raw_Xtest, params, self.Xscaler_lf) + self.ytest_lf = self.transform_new_y(self.lf_model.raw_ytest, self.yscaler_lf) + if self.valid_path: + self.Xvalid = self.transform_new_X(self.raw_Xvalid, params, self.Xscaler) + self.yvalid = self.transform_new_y(self.raw_yvalid, self.yscaler) + + self.Xvalid_lf = self.transform_new_X(self.lf_model.raw_Xvalid, params, self.Xscaler_lf) + self.yvalid_lf = self.transform_new_y(self.lf_model.raw_yvalid, self.yscaler_lf) + else: + raise Exception("Please provide a validation set for Neural Network training.") + else: + self.Xtr = self.X[self.train_indices] + self.ytr = self.y[self.train_indices] + + self.Xtr_lf = self.X_lf[self.lf_model.train_indices] + self.ytr_lf = self.y_lf[self.lf_model.train_indices] + #TODO: this is splitting validation data in the same way at every model build, not necessary. 
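+            # The high- and low-fidelity hold-out points are partitioned independently below, each with a
+            # fixed random_state so the validation/test membership is reproducible between calls (see TODO).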
+ self.valid_indices, self.new_test_indices = train_test_split(self.test_indices, train_size = validation_size, random_state=42) + self.valid_indices_lf, self.new_test_indices_lf = train_test_split(self.lf_model.test_indices, train_size = validation_size_lf, random_state=42) + if validation_size and validation_size_lf: + self.Xvalid = self.X[self.valid_indices] + self.yvalid = self.y[self.valid_indices] + self.Xtest = self.X[self.new_test_indices] + self.ytest = self.y[self.new_test_indices] + + self.Xvalid_lf = self.X_lf[self.valid_indices_lf] + self.yvalid_lf = self.y_lf[self.valid_indices_lf] + self.Xtest_lf = self.X_lf[self.new_test_indices_lf] + self.ytest_lf = self.y_lf[self.new_test_indices_lf] + + else: + raise Exception("Please specify a validation set size for Neural Network training.") + + # convert to Torch Tensors + if precision == 32: + self.Xtr = torch.tensor(self.Xtr, dtype=torch.float32) + self.ytr = torch.tensor(self.ytr, dtype=torch.float32) + self.Xtest = torch.tensor(self.Xtest, dtype=torch.float32) + self.ytest = torch.tensor(self.ytest, dtype=torch.float32) + self.Xvalid = torch.tensor(self.Xvalid,dtype=torch.float32) + self.yvalid = torch.tensor(self.yvalid,dtype=torch.float32) + self.X = torch.tensor(self.X,dtype=torch.float32) + self.y = torch.tensor(self.y,dtype=torch.float32) + + self.Xtr_lf = torch.tensor(self.Xtr_lf, dtype=torch.float32) + self.ytr_lf = torch.tensor(self.ytr_lf, dtype=torch.float32) + self.Xtest_lf = torch.tensor(self.Xtest_lf, dtype=torch.float32) + self.ytest_lf = torch.tensor(self.ytest_lf, dtype=torch.float32) + self.Xvalid_lf = torch.tensor(self.Xvalid_lf,dtype=torch.float32) + self.yvalid_lf = torch.tensor(self.yvalid_lf,dtype=torch.float32) + self.X_lf = torch.tensor(self.X_lf,dtype=torch.float32) + self.y_lf = torch.tensor(self.y_lf,dtype=torch.float32) + elif precision == 64: + self.Xtr = torch.tensor(self.Xtr, dtype=torch.float64) + self.ytr = torch.tensor(self.ytr, dtype=torch.float64) + self.Xtest = torch.tensor(self.Xtest, dtype=torch.float64) + self.ytest = torch.tensor(self.ytest, dtype=torch.float64) + self.Xvalid = torch.tensor(self.Xvalid,dtype=torch.float64) + self.yvalid = torch.tensor(self.yvalid,dtype=torch.float64) + self.X = torch.tensor(self.X,dtype=torch.float64) + self.y = torch.tensor(self.y,dtype=torch.float64) + + self.Xtr_lf = torch.tensor(self.Xtr_lf, dtype=torch.float64) + self.ytr_lf = torch.tensor(self.ytr_lf, dtype=torch.float64) + self.Xtest_lf = torch.tensor(self.Xtest_lf, dtype=torch.float64) + self.ytest_lf = torch.tensor(self.ytest_lf, dtype=torch.float64) + self.Xvalid_lf = torch.tensor(self.Xvalid_lf,dtype=torch.float64) + self.yvalid_lf = torch.tensor(self.yvalid_lf,dtype=torch.float64) + self.X_lf = torch.tensor(self.X_lf,dtype=torch.float64) + self.y_lf = torch.tensor(self.y_lf,dtype=torch.float64) + else: + raise Exception("Invalid option for 'precision'") + + def build_model(self, params, maxit=1000, val_freq=10, es_patience=2, opt='lbfgs', tol=1, decay=False, verbose=False, precision=32, return_model=False): + print("Training LF model:") + model, lf_test_error, lf_val_error, lf_full_error = self.lf_model.build_model(params, maxit, val_freq, es_patience, opt, tol, decay, verbose, precision,return_model=True) + """ + # LF Training + print("Hyperparameters: ", params) + self.split_train_test(params, validation_size=self.nvalid, validation_size_lf=self.nvalid_lf, precision=precision) # split data, according to scaling hp's + scale = params['scale_y'] # Find descaling factor to convert loss to 
original energy units + + if scale == 'std': + loss_descaler = self.yscaler_lf.var_[0] # Here + if scale.startswith('mm'): + loss_descaler = (1/self.yscaler_lf.scale_[0]**2) # Here + + activation = params['scale_X']['activation'] + if activation == 'tanh': + activ = nn.Tanh() + if activation == 'sigmoid': + activ = nn.Sigmoid() + + inp_dim = self.inp_dim + l = params['layers'] + torch.manual_seed(0) + depth = len(l) + structure = OrderedDict([('input', nn.Linear(inp_dim, l[0])), + ('activ_in' , activ)]) + + model = nn.Sequential(structure) # Here + for i in range(depth-1): + model.add_module('layer' + str(i), nn.Linear(l[i], l[i+1])) + model.add_module('activ' + str(i), activ) + model.add_module('output', nn.Linear(l[depth-1], 1)) + if precision == 64: # cast model to proper precision + model = model.double() + metric = torch.nn.MSELoss() + + # Define optimizer + if 'lr' in params: + lr = params['lr'] + elif opt == 'lbfgs': + lr = 0.5 + else: + lr = 0.1 + + optimizer = self.get_optimizer(opt, model.parameters(), lr=lr) + # Define update variables for early stopping, decay, gradient explosion handling + prev_loss = 1.0 + es_tracker = 0 + best_val_error = None + failures = 0 + decay_attempts = 0 + prev_best = None + decay_start = False + for epoch in range(1,maxit): + def closure(): + optimizer.zero_grad() + y_pred = model(self.Xtr_lf) + loss = torch.sqrt(metric(y_pred, self.ytr_lf)) # passing RMSE instead of MSE improves precision IMMENSELY + loss.backward() + return loss + optimizer.step(closure) + # validate + if epoch % val_freq == 0: + with torch.no_grad(): + tmp_pred = model(self.Xvalid_lf) + tmp_loss = metric(tmp_pred, self.yvalid_lf) + val_error_rmse = np.sqrt(tmp_loss.item() * loss_descaler) * hartree2cm # loss_descaler converts MSE in scaled data domain to MSE in unscaled data domain + if best_val_error: + if val_error_rmse < best_val_error: + prev_best = best_val_error * 1.0 + best_val_error = val_error_rmse * 1.0 + else: + record = True + best_val_error = val_error_rmse * 1.0 + prev_best = best_val_error + if verbose: + print("Epoch {} Validation RMSE (cm-1): {:5.3f}".format(epoch, val_error_rmse)) + if decay_start: + scheduler.step(val_error_rmse) + + # Early Stopping + if epoch > 5: + # if current validation error is not the best (current - best > 0) and is within tol of previous error, the model is stagnant. + if ((val_error_rmse - prev_loss) < tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: current validation error is not the best (current - best > 0) and is greater than the best by tol, the model is overfitting. Bad epoch. + elif ((val_error_rmse - best_val_error) > tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: if the current validation error is a new record, but not significant, the model is stagnant + elif (prev_best - best_val_error) < 0.001: + es_tracker += 1 + # else: model set a new record validation error. Reset early stopping tracker + else: + es_tracker = 0 + #TODO this framework does not detect oscillatory behavior about 'tol', though this has not been observed to occur in any case + # Check status of early stopping tracker. First try decaying to see if stagnation can be resolved, if not then terminate training + if es_tracker > es_patience: + if decay: # if decay is set to true, if early stopping criteria is triggered, begin LR scheduler and go back to previous model state and attempt LR decay. 
+ if decay_attempts < 1: + decay_attempts += 1 + es_tracker = 0 + if verbose: + print("Performance plateau detected. Reverting model state and decaying learning rate.") + decay_start = True + thresh = (0.1 / np.sqrt(loss_descaler)) / hartree2cm # threshold is 0.1 wavenumbers + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, threshold=thresh, threshold_mode='abs', min_lr=0.05, cooldown=2, patience=10, verbose=verbose) + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.9 + optimizer.load_state_dict(saved_optimizer_state_dict) + # Since learning rate is decayed, override tolerance, patience, validation frequency for high-precision + #tol = 0.05 + #es_patience = 100 + #val_freq = 1 + continue + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + + # Handle exploding gradients + if epoch > 10: + if (val_error_rmse > prev_loss*10): # detect large increases in loss + if epoch > 60: # distinguish between exploding gradients at near converged models and early on exploding grads + if verbose: + print("Exploding gradient detected. Resuming previous model state and decaying learning rate") + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.5 + optimizer.load_state_dict(saved_optimizer_state_dict) + failures += 1 # if + if failures > 2: + break + else: + continue + else: + break + if val_error_rmse != val_error_rmse: # detect NaN + break + if ((prev_loss < 1.0) and (precision == 32)): # if 32 bit precision and model is giving very high accuracy, kill so the accuracy does not go beyond 32 bit precision + break + prev_loss = val_error_rmse * 1.0 # save previous loss to track improvement + + # Periodically save model state so we can reset under instability/overfitting/performance plateau + if epoch % 50 == 0: + saved_model_state_dict = copy.deepcopy(model.state_dict()) + saved_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + with torch.no_grad(): + test_pred = model(self.Xtest_lf) + test_loss = metric(test_pred, self.ytest_lf) + test_error_rmse = np.sqrt(test_loss.item() * loss_descaler) * hartree2cm + val_pred = model(self.Xvalid_lf) + val_loss = metric(val_pred, self.yvalid_lf) + val_error_rmse = np.sqrt(val_loss.item() * loss_descaler) * hartree2cm + full_pred = model(self.X_lf) + full_loss = metric(full_pred, self.y_lf) + full_error_rmse = np.sqrt(full_loss.item() * loss_descaler) * hartree2cm + """ + + # HF Training + self.split_train_test(params, validation_size=self.nvalid, validation_size_lf=self.nvalid_lf, precision=precision) # split data, according to scaling hp's + scale = params['scale_y'] # Find descaling factor to convert loss to original energy units + if scale == 'std': + loss_descaler = self.yscaler.var_[0] + if scale.startswith('mm'): + loss_descaler = (1/self.yscaler.scale_[0]**2) + + # Define update variables for early stopping, decay, gradient explosion handling + prev_loss = 1.0 + es_tracker = 0 + best_val_error = None + failures = 0 + decay_attempts = 0 + prev_best = None + decay_start = False + + # Define optimizer + if 'lr' in params: + lr = params['lr'] + elif opt == 'lbfgs': + lr = 0.5 + else: + lr = 0.1 + optimizer = self.get_optimizer(opt, model.parameters(), lr=lr) + metric = torch.nn.MSELoss() + + saved_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) 
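+        # Weight transfer step: `model` still carries the parameters just trained on the low-fidelity set.
+        # Instead of reinitializing, training continues on the high-fidelity data, and the copied optimizer
+        # state has its learning rate cut to a tenth below so the transferred weights are only fine-tuned.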
+ saved_optimizer_state_dict['param_groups'][0]['lr'] = lr * 0.1 + optimizer.load_state_dict(saved_optimizer_state_dict) + for epoch in range(1,maxit): + def closure(): + optimizer.zero_grad() + y_pred = model(self.Xtr) + loss = torch.sqrt(metric(y_pred, self.ytr)) # passing RMSE instead of MSE improves precision IMMENSELY + loss.backward() + return loss + optimizer.step(closure) + # validate + if epoch % val_freq == 0: + with torch.no_grad(): + tmp_pred = model(self.Xvalid) + tmp_loss = metric(tmp_pred, self.yvalid) + val_error_rmse = np.sqrt(tmp_loss.item() * loss_descaler) * hartree2cm # loss_descaler converts MSE in scaled data domain to MSE in unscaled data domain + if best_val_error: + if val_error_rmse < best_val_error: + prev_best = best_val_error * 1.0 + best_val_error = val_error_rmse * 1.0 + else: + record = True + best_val_error = val_error_rmse * 1.0 + prev_best = best_val_error + if verbose: + print("Epoch {} Validation RMSE (cm-1): {:5.3f}".format(epoch, val_error_rmse)) + if decay_start: + scheduler.step(val_error_rmse) + + # Early Stopping + if epoch > 5: + # if current validation error is not the best (current - best > 0) and is within tol of previous error, the model is stagnant. + if ((val_error_rmse - prev_loss) < tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: current validation error is not the best (current - best > 0) and is greater than the best by tol, the model is overfitting. Bad epoch. + elif ((val_error_rmse - best_val_error) > tol) and (val_error_rmse - best_val_error) > 0.0: + es_tracker += 1 + # else if: if the current validation error is a new record, but not significant, the model is stagnant + elif (prev_best - best_val_error) < 0.001: + es_tracker += 1 + # else: model set a new record validation error. Reset early stopping tracker + else: + es_tracker = 0 + #TODO this framework does not detect oscillatory behavior about 'tol', though this has not been observed to occur in any case + # Check status of early stopping tracker. First try decaying to see if stagnation can be resolved, if not then terminate training + if es_tracker > es_patience: + if decay: # if decay is set to true, if early stopping criteria is triggered, begin LR scheduler and go back to previous model state and attempt LR decay. + if decay_attempts < 1: + decay_attempts += 1 + es_tracker = 0 + if verbose: + print("Performance plateau detected. 
Reverting model state and decaying learning rate.") + decay_start = True + thresh = (0.1 / np.sqrt(loss_descaler)) / hartree2cm # threshold is 0.1 wavenumbers + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, threshold=thresh, threshold_mode='abs', min_lr=0.05, cooldown=2, patience=10, verbose=verbose) + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.9 + optimizer.load_state_dict(saved_optimizer_state_dict) + # Since learning rate is decayed, override tolerance, patience, validation frequency for high-precision + #tol = 0.05 + #es_patience = 100 + #val_freq = 1 + continue + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + else: + prev_loss = val_error_rmse * 1.0 + if verbose: + print('Early stopping termination') + break + + # Handle exploding gradients + if epoch > 10: + if (val_error_rmse > prev_loss*10): # detect large increases in loss + if epoch > 60: # distinguish between exploding gradients at near converged models and early on exploding grads + if verbose: + print("Exploding gradient detected. Resuming previous model state and decaying learning rate") + model.load_state_dict(saved_model_state_dict) + saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.5 + optimizer.load_state_dict(saved_optimizer_state_dict) + failures += 1 # if + if failures > 2: + break + else: + continue + else: + break + if val_error_rmse != val_error_rmse: # detect NaN + break + if ((prev_loss < 1.0) and (precision == 32)): # if 32 bit precision and model is giving very high accuracy, kill so the accuracy does not go beyond 32 bit precision + break + prev_loss = val_error_rmse * 1.0 # save previous loss to track improvement + + # Periodically save model state so we can reset under instability/overfitting/performance plateau + if epoch % 50 == 0: + saved_model_state_dict = copy.deepcopy(model.state_dict()) + saved_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + with torch.no_grad(): + test_pred = model(self.Xtest) + test_loss = metric(test_pred, self.ytest) + test_error_rmse = np.sqrt(test_loss.item() * loss_descaler) * hartree2cm + val_pred = model(self.Xvalid) + val_loss = metric(val_pred, self.yvalid) + val_error_rmse = np.sqrt(val_loss.item() * loss_descaler) * hartree2cm + full_pred = model(self.X) + full_loss = metric(full_pred, self.y) + full_error_rmse = np.sqrt(full_loss.item() * loss_descaler) * hartree2cm + print("HF: Test set RMSE (cm-1): {:5.2f} Validation set RMSE (cm-1): {:5.2f} Full dataset RMSE (cm-1): {:5.2f}".format(test_error_rmse, val_error_rmse, full_error_rmse)) + + if return_model: + return model, test_error_rmse, val_error_rmse, full_error_rmse + else: + return test_error_rmse, val_error_rmse + From 196c2d329a0fcfdc8a632208c5a9c11f03d16250 Mon Sep 17 00:00:00 2001 From: sgoodlett Date: Mon, 18 Aug 2025 16:29:35 +0200 Subject: [PATCH 2/2] Small clean up --- peslearn/ml/mfnn/mknn.py | 9 -- peslearn/ml/mfnn/weight_transfer.py | 160 ---------------------------- 2 files changed, 169 deletions(-) diff --git a/peslearn/ml/mfnn/mknn.py b/peslearn/ml/mfnn/mknn.py index 98e810b..36f03a9 100644 --- a/peslearn/ml/mfnn/mknn.py +++ b/peslearn/ml/mfnn/mknn.py @@ -19,13 +19,6 @@ def __init__(self, inp_dim, layers, activ) -> None: self.model_lf.add_module('activ' + str(i), activ) self.model_lf.add_module('output', nn.Linear(layers[depth-1], 1)) - #structure_hf = OrderedDict([('input', nn.Linear(inp_dim+1, layers[0])), - # ('activ_in' , 
activ)]) # Add one to inp_dim for LF energy - #self.nonlinear_hf = nn.Sequential(structure_hf) # Nonlinear NN for HF prediction - #for i in range(depth-1): - # self.nonlinear_hf.add_module('layer' + str(i), nn.Linear(layers[i], layers[i+1])) - # self.nonlinear_hf.add_module('activ' + str(i), activ) - #self.nonlinear_hf.add_module('output', nn.Linear(layers[depth-1], 1)) self.nonlinear_hf = nn.Sequential( nn.Linear(inp_dim+1,32), nn.Tanh(), @@ -41,8 +34,6 @@ def __init__(self, inp_dim, layers, activ) -> None: def forward(self, xh, xl): yl = self.model_lf(xl) yl_xh = self.model_lf(xh) - #print(xh.shape) - #print(yl_xh.shape) hin = torch.cat((xh,yl_xh), dim=1) nliny = self.nonlinear_hf(hin) liny = self.linear_hf(hin) diff --git a/peslearn/ml/mfnn/weight_transfer.py b/peslearn/ml/mfnn/weight_transfer.py index 306596a..0ed8b1b 100644 --- a/peslearn/ml/mfnn/weight_transfer.py +++ b/peslearn/ml/mfnn/weight_transfer.py @@ -107,166 +107,6 @@ def split_train_test(self, params, validation_size=None, validation_size_lf=None def build_model(self, params, maxit=1000, val_freq=10, es_patience=2, opt='lbfgs', tol=1, decay=False, verbose=False, precision=32, return_model=False): print("Training LF model:") model, lf_test_error, lf_val_error, lf_full_error = self.lf_model.build_model(params, maxit, val_freq, es_patience, opt, tol, decay, verbose, precision,return_model=True) - """ - # LF Training - print("Hyperparameters: ", params) - self.split_train_test(params, validation_size=self.nvalid, validation_size_lf=self.nvalid_lf, precision=precision) # split data, according to scaling hp's - scale = params['scale_y'] # Find descaling factor to convert loss to original energy units - - if scale == 'std': - loss_descaler = self.yscaler_lf.var_[0] # Here - if scale.startswith('mm'): - loss_descaler = (1/self.yscaler_lf.scale_[0]**2) # Here - - activation = params['scale_X']['activation'] - if activation == 'tanh': - activ = nn.Tanh() - if activation == 'sigmoid': - activ = nn.Sigmoid() - - inp_dim = self.inp_dim - l = params['layers'] - torch.manual_seed(0) - depth = len(l) - structure = OrderedDict([('input', nn.Linear(inp_dim, l[0])), - ('activ_in' , activ)]) - - model = nn.Sequential(structure) # Here - for i in range(depth-1): - model.add_module('layer' + str(i), nn.Linear(l[i], l[i+1])) - model.add_module('activ' + str(i), activ) - model.add_module('output', nn.Linear(l[depth-1], 1)) - if precision == 64: # cast model to proper precision - model = model.double() - metric = torch.nn.MSELoss() - - # Define optimizer - if 'lr' in params: - lr = params['lr'] - elif opt == 'lbfgs': - lr = 0.5 - else: - lr = 0.1 - - optimizer = self.get_optimizer(opt, model.parameters(), lr=lr) - # Define update variables for early stopping, decay, gradient explosion handling - prev_loss = 1.0 - es_tracker = 0 - best_val_error = None - failures = 0 - decay_attempts = 0 - prev_best = None - decay_start = False - for epoch in range(1,maxit): - def closure(): - optimizer.zero_grad() - y_pred = model(self.Xtr_lf) - loss = torch.sqrt(metric(y_pred, self.ytr_lf)) # passing RMSE instead of MSE improves precision IMMENSELY - loss.backward() - return loss - optimizer.step(closure) - # validate - if epoch % val_freq == 0: - with torch.no_grad(): - tmp_pred = model(self.Xvalid_lf) - tmp_loss = metric(tmp_pred, self.yvalid_lf) - val_error_rmse = np.sqrt(tmp_loss.item() * loss_descaler) * hartree2cm # loss_descaler converts MSE in scaled data domain to MSE in unscaled data domain - if best_val_error: - if val_error_rmse < best_val_error: - 
prev_best = best_val_error * 1.0 - best_val_error = val_error_rmse * 1.0 - else: - record = True - best_val_error = val_error_rmse * 1.0 - prev_best = best_val_error - if verbose: - print("Epoch {} Validation RMSE (cm-1): {:5.3f}".format(epoch, val_error_rmse)) - if decay_start: - scheduler.step(val_error_rmse) - - # Early Stopping - if epoch > 5: - # if current validation error is not the best (current - best > 0) and is within tol of previous error, the model is stagnant. - if ((val_error_rmse - prev_loss) < tol) and (val_error_rmse - best_val_error) > 0.0: - es_tracker += 1 - # else if: current validation error is not the best (current - best > 0) and is greater than the best by tol, the model is overfitting. Bad epoch. - elif ((val_error_rmse - best_val_error) > tol) and (val_error_rmse - best_val_error) > 0.0: - es_tracker += 1 - # else if: if the current validation error is a new record, but not significant, the model is stagnant - elif (prev_best - best_val_error) < 0.001: - es_tracker += 1 - # else: model set a new record validation error. Reset early stopping tracker - else: - es_tracker = 0 - #TODO this framework does not detect oscillatory behavior about 'tol', though this has not been observed to occur in any case - # Check status of early stopping tracker. First try decaying to see if stagnation can be resolved, if not then terminate training - if es_tracker > es_patience: - if decay: # if decay is set to true, if early stopping criteria is triggered, begin LR scheduler and go back to previous model state and attempt LR decay. - if decay_attempts < 1: - decay_attempts += 1 - es_tracker = 0 - if verbose: - print("Performance plateau detected. Reverting model state and decaying learning rate.") - decay_start = True - thresh = (0.1 / np.sqrt(loss_descaler)) / hartree2cm # threshold is 0.1 wavenumbers - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, threshold=thresh, threshold_mode='abs', min_lr=0.05, cooldown=2, patience=10, verbose=verbose) - model.load_state_dict(saved_model_state_dict) - saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.9 - optimizer.load_state_dict(saved_optimizer_state_dict) - # Since learning rate is decayed, override tolerance, patience, validation frequency for high-precision - #tol = 0.05 - #es_patience = 100 - #val_freq = 1 - continue - else: - prev_loss = val_error_rmse * 1.0 - if verbose: - print('Early stopping termination') - break - else: - prev_loss = val_error_rmse * 1.0 - if verbose: - print('Early stopping termination') - break - - # Handle exploding gradients - if epoch > 10: - if (val_error_rmse > prev_loss*10): # detect large increases in loss - if epoch > 60: # distinguish between exploding gradients at near converged models and early on exploding grads - if verbose: - print("Exploding gradient detected. 
Resuming previous model state and decaying learning rate") - model.load_state_dict(saved_model_state_dict) - saved_optimizer_state_dict['param_groups'][0]['lr'] = lr*0.5 - optimizer.load_state_dict(saved_optimizer_state_dict) - failures += 1 # if - if failures > 2: - break - else: - continue - else: - break - if val_error_rmse != val_error_rmse: # detect NaN - break - if ((prev_loss < 1.0) and (precision == 32)): # if 32 bit precision and model is giving very high accuracy, kill so the accuracy does not go beyond 32 bit precision - break - prev_loss = val_error_rmse * 1.0 # save previous loss to track improvement - - # Periodically save model state so we can reset under instability/overfitting/performance plateau - if epoch % 50 == 0: - saved_model_state_dict = copy.deepcopy(model.state_dict()) - saved_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) - - with torch.no_grad(): - test_pred = model(self.Xtest_lf) - test_loss = metric(test_pred, self.ytest_lf) - test_error_rmse = np.sqrt(test_loss.item() * loss_descaler) * hartree2cm - val_pred = model(self.Xvalid_lf) - val_loss = metric(val_pred, self.yvalid_lf) - val_error_rmse = np.sqrt(val_loss.item() * loss_descaler) * hartree2cm - full_pred = model(self.X_lf) - full_loss = metric(full_pred, self.y_lf) - full_error_rmse = np.sqrt(full_loss.item() * loss_descaler) * hartree2cm - """ # HF Training self.split_train_test(params, validation_size=self.nvalid, validation_size_lf=self.nvalid_lf, precision=precision) # split data, according to scaling hp's
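In brief, the four new models cover different multifidelity strategies: DeltaNN takes the low-fidelity energy from the last column of the dataset and fits the correction E_HF - E_LF; DualNN keeps that column, scales it separately, and appends it to the input features; WTNN pretrains an ordinary NeuralNetwork on the low-fidelity dataset and then fine-tunes the same weights on the high-fidelity dataset at a tenth of the learning rate; MKNN couples a low-fidelity sub-network with linear and nonlinear high-fidelity heads that act on the geometry concatenated with the low-fidelity prediction. A minimal, self-contained sketch of that MKNN-style composition (hypothetical layer sizes and random data, plain torch rather than the package API):

    import torch
    import torch.nn as nn

    inp_dim = 3                                   # e.g. three interatomic distances (hypothetical)
    lf_net = nn.Sequential(nn.Linear(inp_dim, 32), nn.Tanh(), nn.Linear(32, 1))      # low-fidelity surrogate
    nonlin = nn.Sequential(nn.Linear(inp_dim + 1, 32), nn.Tanh(), nn.Linear(32, 1))  # nonlinear HF head
    lin = nn.Linear(inp_dim + 1, 1)                                                   # linear HF head

    x = torch.rand(5, inp_dim)                    # five hypothetical geometries
    y_lf = lf_net(x)                              # low-fidelity prediction at those geometries
    h = torch.cat((x, y_lf), dim=1)               # input to the HF heads: geometry + LF prediction
    y_hf = lin(h) + nonlin(h)                     # high-fidelity energy = linear + nonlinear corrections
    print(y_hf.shape)                             # torch.Size([5, 1])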