diff --git a/openadmet/models/anvil/workflow.py b/openadmet/models/anvil/workflow.py index 4478e4f5..e918bead 100644 --- a/openadmet/models/anvil/workflow.py +++ b/openadmet/models/anvil/workflow.py @@ -18,6 +18,7 @@ from openadmet.models.drivers import DriverType from openadmet.models.anvil.workflow_base import AnvilWorkflowBase +from openadmet.models.features.pairwise_featurizer import PairwiseFeaturizer def _safe_to_numpy(X): @@ -72,8 +73,8 @@ def check_no_finetuning(self): # Ensemble specified if self.ensemble: # Fine-tuning paths specified - if (self.parent_spec.procedure.ensemble.param_paths is not None) or ( - self.parent_spec.procedure.ensemble.serial_paths is not None + if (self.ensemble.param_paths is not None) or ( + self.ensemble.serial_paths is not None ): raise ValueError( "Finetuning from serialized ensemble models is not supported in this workflow." ) @@ -82,8 +83,8 @@ def check_no_finetuning(self): # No ensemble else: # Fine-tuning paths supplied - if (self.parent_spec.procedure.model.param_path is not None) or ( - self.parent_spec.procedure.model.serial_path is not None + if (self.model.param_path is not None) or ( + self.model.serial_path is not None ): raise ValueError( "Finetuning from serialized model is not supported in this workflow." 
@@ -117,7 +118,7 @@ def _train_ensemble(self, X_train_feat, y_train, output_dir, **kwargs): # Bootstrap iterations models = [] - for i in range(self.parent_spec.procedure.ensemble.n_models): + for i in range(self.ensemble.n_models): # Manage bootstrap directory bootstrap_dir = output_dir / f"bootstrap_{i}" bootstrap_dir.mkdir(parents=True, exist_ok=True) @@ -216,17 +217,19 @@ def run( data_dir.mkdir(parents=True, exist_ok=True) # Write recipe to output directory - self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") + if self.parent_spec is not None: + self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") # Split recipe into components and save recipe_components = Path(output_dir / "recipe_components") recipe_components.mkdir(parents=True, exist_ok=True) - self.parent_spec.to_multi_yaml( - metadata_yaml=recipe_components / "metadata.yaml", - procedure_yaml=recipe_components / "procedure.yaml", - data_yaml=recipe_components / "data.yaml", - report_yaml=recipe_components / "eval.yaml", - ) + if self.parent_spec is not None: + self.parent_spec.to_multi_yaml( + metadata_yaml=recipe_components / "metadata.yaml", + procedure_yaml=recipe_components / "procedure.yaml", + data_yaml=recipe_components / "data.yaml", + report_yaml=recipe_components / "eval.yaml", + ) # Log output directory information logger.info(f"Running workflow from directory {output_dir}") @@ -322,7 +325,7 @@ def run( self.model.calibrate_uncertainty( X_val_feat, y_val, - method=self.parent_spec.procedure.ensemble.calibration_method, + method=self.ensemble.calibration_method, ) # Save @@ -449,14 +452,11 @@ def _train( self, train_dataloader, val_dataloader, train_scaler, output_dir, **kwargs ): # Load model from disk - if ( - self.parent_spec.procedure.model.param_path is not None - and self.parent_spec.procedure.model.serial_path is not None - ): + if self.model.param_path is not None and self.model.serial_path is not None: logger.info("Loading model from disk, overrides any 
specified parameters.") self.model = self.model.deserialize( - self.parent_spec.procedure.model.param_path, - self.parent_spec.procedure.model.serial_path, + self.model.param_path, + self.model.serial_path, scaler=train_scaler, **kwargs, ) @@ -464,11 +464,9 @@ def _train( logger.info("Model loaded") # Optionally freeze weights - if self.parent_spec.procedure.model.freeze_weights is not None: + if self.model.freeze_weights is not None: logger.info(f"Freezing model weights") - self.model.freeze_weights( - **self.parent_spec.procedure.model.freeze_weights - ) + self.model.freeze_weights(**self.model.freeze_weights) logger.info(f"Model weights frozen") # Build model from scratch @@ -507,7 +505,7 @@ def _train_ensemble(self, X_train, y_train, val_dataloader, output_dir, **kwargs # Bootstrap iterations models = [] - for i in range(self.parent_spec.procedure.ensemble.n_models): + for i in range(self.ensemble.n_models): # Manage bootstrap directory bootstrap_dir = output_dir / f"bootstrap_{i}" bootstrap_dir.mkdir(parents=True, exist_ok=True) @@ -540,26 +538,24 @@ def _train_ensemble(self, X_train, y_train, val_dataloader, output_dir, **kwargs logger.info("Data featurized") # Load model from disk - if (self.parent_spec.procedure.ensemble.param_paths is not None) and ( - self.parent_spec.procedure.ensemble.serial_paths is not None + if (self.param_paths is not None) and ( + self.ensemble.serial_paths is not None ): logger.info( f"Loading model {i} from disk, overrides any specified parameters." 
) self.model = self.model.deserialize( - self.parent_spec.procedure.ensemble.param_paths[i], - self.parent_spec.procedure.ensemble.serial_paths[i], + self.ensemble.param_paths[i], + self.ensemble.serial_paths[i], scaler=bootstrap_scaler, **kwargs, ) logger.info(f"Model {i} loaded") # Optionally freeze weights - if self.parent_spec.procedure.model.freeze_weights is not None: + if self.model.freeze_weights is not None: logger.info(f"Freezing weights for model {i}") - self.model.freeze_weights( - **self.parent_spec.procedure.model.freeze_weights - ) + self.model.freeze_weights(**self.model.freeze_weights) logger.info(f"Model {i} frozen") # Build model from scratch @@ -655,17 +651,19 @@ def run( data_dir.mkdir(parents=True, exist_ok=True) # Write recipe to output directory - self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") + if self.parent_spec is not None: + self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") # Split recipe into components and save recipe_components = Path(output_dir / "recipe_components") recipe_components.mkdir(parents=True, exist_ok=True) - self.parent_spec.to_multi_yaml( - metadata_yaml=recipe_components / "metadata.yaml", - procedure_yaml=recipe_components / "procedure.yaml", - data_yaml=recipe_components / "data.yaml", - report_yaml=recipe_components / "eval.yaml", - ) + if self.parent_spec is not None: + self.parent_spec.to_multi_yaml( + metadata_yaml=recipe_components / "metadata.yaml", + procedure_yaml=recipe_components / "procedure.yaml", + data_yaml=recipe_components / "data.yaml", + report_yaml=recipe_components / "eval.yaml", + ) # Log output directory information logger.info(f"Running workflow from directory {output_dir}") @@ -733,7 +731,7 @@ def run( logger.info("Data featurized") kwargs = {} - if self.parent_spec.procedure.feat.type == "PairwiseFeaturizer": + if isinstance(self.feat, PairwiseFeaturizer): kwargs["input_dim"] = train_dataset[0][0].shape[ -1 ] # this is the dimension of # of features, e.g. 
1024 for ECFP4, variable for descriptors @@ -756,7 +754,7 @@ def run( self.model.calibrate_uncertainty( val_dataloader, y_val, - method=self.parent_spec.procedure.ensemble.calibration_method, + method=self.ensemble.calibration_method, accelerator=self.trainer.accelerator, devices=self.trainer.devices, ) diff --git a/openadmet/models/anvil/workflow_base.py b/openadmet/models/anvil/workflow_base.py index 77ee0922..5cd1f02f 100644 --- a/openadmet/models/anvil/workflow_base.py +++ b/openadmet/models/anvil/workflow_base.py @@ -61,7 +61,9 @@ class AnvilWorkflowBase(BaseModel): ensemble: EnsembleBase | None = None trainer: TrainerBase evals: list[EvalBase] - parent_spec: AnvilSpecification + parent_spec: AnvilSpecification | None = ( + None # Optional reference to parent specification + ) debug: bool = False @abstractmethod