From cc96d3b61c18a6dbbbbce3d3c6f8d63b5e670ed6 Mon Sep 17 00:00:00 2001 From: Gil Robalo Rei Date: Tue, 25 Nov 2025 17:02:13 +0100 Subject: [PATCH] feat: allow function in simulation model --- src/queens/drivers/_driver.py | 35 ++++++++++--------- src/queens/drivers/function.py | 9 +++-- src/queens/drivers/jobscript.py | 12 +++---- src/queens/models/adjoint.py | 4 ++- src/queens/models/simulation.py | 25 ++++++------- tests/unit_tests/drivers/test_jobscript.py | 6 ++-- tests/unit_tests/iterators/conftest.py | 16 ++++++--- .../iterators/test_elementary_effects.py | 6 ++-- .../test_latin_hypercube_sampling.py | 8 ++--- .../unit_tests/iterators/test_monte_carlo.py | 8 ++--- .../unit_tests/iterators/test_sobol_index.py | 12 +++---- .../iterators/test_sobol_sequence.py | 7 ++-- .../models/test_differentiable_adjoint.py | 2 +- .../models/test_differentiable_fd.py | 2 +- tests/unit_tests/models/test_simulation.py | 2 +- 15 files changed, 73 insertions(+), 81 deletions(-) diff --git a/src/queens/drivers/_driver.py b/src/queens/drivers/_driver.py index b1eb1e0d0..cd5fbe109 100644 --- a/src/queens/drivers/_driver.py +++ b/src/queens/drivers/_driver.py @@ -48,7 +48,7 @@ def __init__(self, parameters, files_to_copy=None): @abc.abstractmethod def run( self, - sample: np.ndarray, + inputs: dict, job_id: int, num_procs: int, experiment_dir: Path, @@ -57,7 +57,7 @@ def run( """Abstract method for driver run. Args: - sample (np.ndarray): Input sample + inputs (dict): Input sample job_id (int): Job ID num_procs (int): number of processors experiment_name (str): name of QUEENS experiment. @@ -67,17 +67,20 @@ def run( Results """ - def __call__(self, sample, job_id, num_procs, experiment_dir, experiment_name): - """Abstract method for driver run. - - Args: - sample (np.ndarray): Input sample - job_id (int): Job ID - num_procs (int): number of processors - experiment_name (str): name of QUEENS experiment. - experiment_dir (Path): Path to QUEENS experiment directory. - - Returns: - Result and potentially the gradient - """ - return self.run(sample, job_id, num_procs, experiment_dir, experiment_name) + def run_from_parameters( + self, + sample: np.ndarray, + job_id: int, + num_procs: int, + experiment_dir: Path, + experiment_name: str, + ) -> dict: + """Create inputs from parameters and run.""" + sample_dict = self.parameters.sample_as_dict(sample) + return self.run( + sample_dict, + job_id, + num_procs, + experiment_dir, + experiment_name, + ) diff --git a/src/queens/drivers/function.py b/src/queens/drivers/function.py index 2e2f881a0..bb9e39530 100644 --- a/src/queens/drivers/function.py +++ b/src/queens/drivers/function.py @@ -113,7 +113,7 @@ def reshaped_output_function(sample_dict): def run( self, - sample: np.ndarray, + inputs: dict, job_id: int, num_procs: int, experiment_dir: Path, @@ -122,7 +122,7 @@ def run( """Run the driver. Args: - sample (np.ndarray): Input sample + inputs (dict): Input sample job_id (int): Job ID num_procs (int): number of processors experiment_name (str): name of QUEENS experiment. @@ -131,8 +131,7 @@ def run( Returns: Result and potentially the gradient """ - sample_dict = self.parameters.sample_as_dict(sample) if self.function_requires_job_id: - sample_dict["job_id"] = job_id - results = self.function(sample_dict) + inputs["job_id"] = job_id + results = self.function(inputs) return results diff --git a/src/queens/drivers/jobscript.py b/src/queens/drivers/jobscript.py index 75466808e..9ac3d1449 100644 --- a/src/queens/drivers/jobscript.py +++ b/src/queens/drivers/jobscript.py @@ -20,8 +20,6 @@ from dataclasses import dataclass from pathlib import Path -import numpy as np - from queens.drivers._driver import Driver from queens.utils.exceptions import SubprocessError from queens.utils.injector import inject, inject_in_template @@ -193,7 +191,7 @@ def get_read_in_jobscript_template(jobscript_template): def run( self, - sample: np.ndarray, + inputs: dict, job_id: int, num_procs: int, experiment_dir: Path, @@ -202,7 +200,7 @@ def run( """Run the driver. Args: - sample (np.array): Input sample. + inputs (dict): Input sample. job_id (int): Job ID. num_procs (int): Number of processors. experiment_dir (Path): Path to QUEENS experiment directory. @@ -215,9 +213,7 @@ def run( job_id, experiment_dir ) - sample_dict = self.parameters.sample_as_dict(sample) - - metadata = SimulationMetadata(job_id=job_id, inputs=sample_dict, job_dir=job_dir) + metadata = SimulationMetadata(job_id=job_id, inputs=inputs, job_dir=job_dir) with metadata.time_code("prepare_input_files"): job_options = JobOptions( @@ -233,7 +229,7 @@ def run( # Create the input files self.prepare_input_files( - job_options.add_data_and_to_dict(sample_dict), experiment_dir, input_files + job_options.add_data_and_to_dict(inputs), experiment_dir, input_files ) jobscript_file = job_dir / self.jobscript_file_name diff --git a/src/queens/models/adjoint.py b/src/queens/models/adjoint.py index e08430998..a3b2302ff 100644 --- a/src/queens/models/adjoint.py +++ b/src/queens/models/adjoint.py @@ -83,6 +83,8 @@ def grad(self, samples, upstream_gradient): # evaluate the adjoint model gradient = self.create_result_dict_from_scheduler_output( - self.scheduler.evaluate(samples, self.gradient_driver, job_ids=last_job_ids) + self.scheduler.evaluate( + samples, self.gradient_driver.run_from_parameters, job_ids=last_job_ids + ) )["result"] return gradient diff --git a/src/queens/models/simulation.py b/src/queens/models/simulation.py index 413e9b49f..b0e5c03cb 100644 --- a/src/queens/models/simulation.py +++ b/src/queens/models/simulation.py @@ -16,30 +16,31 @@ import numpy as np +from queens.drivers._driver import Driver from queens.models._model import Model +from queens.schedulers._scheduler import Scheduler, SchedulerCallableSignature from queens.utils.logger_settings import log_init_args class Simulation(Model): - """Simulation model class. - - Attributes: - scheduler (Scheduler): Scheduler for the simulations - driver (Driver): Driver for the simulations - """ + """Simulation model class.""" @log_init_args - def __init__(self, scheduler, driver): + def __init__(self, scheduler: Scheduler, driver: Driver | SchedulerCallableSignature): """Initialize simulation model. Args: - scheduler (Scheduler): Scheduler for the simulations - driver (Driver): Driver for the simulations + scheduler: Scheduler for the simulations + driver: Driver for the simulations """ super().__init__() self.scheduler = scheduler - self.driver = driver - self.scheduler.copy_files_to_experiment_dir(self.driver.files_to_copy) + self.function: SchedulerCallableSignature + if isinstance(driver, Driver): + self.function = driver.run_from_parameters + self.scheduler.copy_files_to_experiment_dir(driver.files_to_copy) + else: + self.function = driver def _evaluate(self, samples: np.ndarray) -> dict: """Evaluate model with current set of input samples. @@ -51,7 +52,7 @@ def _evaluate(self, samples: np.ndarray) -> dict: response (dict): Response of the underlying model at input samples """ self.response = self.create_result_dict_from_scheduler_output( - self.scheduler.evaluate(samples, self.driver) + self.scheduler.evaluate(samples, self.function) ) return self.response diff --git a/tests/unit_tests/drivers/test_jobscript.py b/tests/unit_tests/drivers/test_jobscript.py index 8924dae86..31d44f604 100644 --- a/tests/unit_tests/drivers/test_jobscript.py +++ b/tests/unit_tests/drivers/test_jobscript.py @@ -244,7 +244,7 @@ def test_multiple_input_files(jobscript_driver, job_options, injected_input_file sample = np.array(list(sample_dict.values())) # Run the driver - jobscript_driver.run( + jobscript_driver.run_from_parameters( sample=sample, job_id=job_options.job_id, num_procs=job_options.num_procs, @@ -283,7 +283,7 @@ def test_error_in_jobscript_template( sample = np.array(list(sample_dict.values())) with expectation: - jobscript_driver.run( + jobscript_driver.run_from_parameters( sample=sample, job_id=job_options.job_id, num_procs=job_options.num_procs, @@ -314,7 +314,7 @@ def test_nonzero_exit_code( sample = np.array(list(sample_dict.values())) with expectation: - jobscript_driver.run( + jobscript_driver.run_from_parameters( sample=sample, job_id=job_options.job_id, num_procs=job_options.num_procs, diff --git a/tests/unit_tests/iterators/conftest.py b/tests/unit_tests/iterators/conftest.py index 67eb7aca7..9ba420532 100644 --- a/tests/unit_tests/iterators/conftest.py +++ b/tests/unit_tests/iterators/conftest.py @@ -17,7 +17,6 @@ from copy import deepcopy import pytest -from mock import Mock from queens.distributions.lognormal import LogNormal from queens.distributions.normal import Normal @@ -28,10 +27,19 @@ from queens.schedulers.local import Local -@pytest.fixture(name="default_simulation_model") -def fixture_default_simulation_model(): +@pytest.fixture(name="ishigami_90_uniform") +def fixture_ishigami_90_uniform(default_parameters_uniform_3d): """Default simulation model.""" - driver = Function(parameters=Mock(), function="ishigami90") + driver = Function(parameters=default_parameters_uniform_3d, function="ishigami90") + scheduler = Local(experiment_name="dummy_experiment_name") + model = Simulation(scheduler=scheduler, driver=driver) + return model + + +@pytest.fixture(name="ishigami_90_mixed") +def fixture_ishigami_90_mixed(default_parameters_mixed): + """Default simulation model.""" + driver = Function(parameters=default_parameters_mixed, function="ishigami90") scheduler = Local(experiment_name="dummy_experiment_name") model = Simulation(scheduler=scheduler, driver=driver) return model diff --git a/tests/unit_tests/iterators/test_elementary_effects.py b/tests/unit_tests/iterators/test_elementary_effects.py index 92e40eb20..4448b91c4 100644 --- a/tests/unit_tests/iterators/test_elementary_effects.py +++ b/tests/unit_tests/iterators/test_elementary_effects.py @@ -22,13 +22,11 @@ @pytest.fixture(name="default_elementary_effects_iterator") def fixture_default_elementary_effects_iterator( - global_settings, default_simulation_model, default_parameters_uniform_3d + global_settings, ishigami_90_uniform, default_parameters_uniform_3d ): """Default elementary effects iterator.""" - default_simulation_model.driver.parameters = default_parameters_uniform_3d - my_iterator = ElementaryEffects( - model=default_simulation_model, + model=ishigami_90_uniform, parameters=default_parameters_uniform_3d, global_settings=global_settings, num_trajectories=20, diff --git a/tests/unit_tests/iterators/test_latin_hypercube_sampling.py b/tests/unit_tests/iterators/test_latin_hypercube_sampling.py index aebbf884f..8f5ed06a9 100644 --- a/tests/unit_tests/iterators/test_latin_hypercube_sampling.py +++ b/tests/unit_tests/iterators/test_latin_hypercube_sampling.py @@ -21,16 +21,12 @@ @pytest.fixture(name="default_lhs_iterator") -def fixture_default_lhs_iterator( - global_settings, default_simulation_model, default_parameters_mixed -): +def fixture_default_lhs_iterator(global_settings, ishigami_90_mixed, default_parameters_mixed): """Default latin hypercube sampling iterator.""" - default_simulation_model.driver.parameters = default_parameters_mixed - # create LHS iterator # pylint: disable=duplicate-code my_iterator = LatinHypercubeSampling( - model=default_simulation_model, + model=ishigami_90_mixed, parameters=default_parameters_mixed, global_settings=global_settings, seed=42, diff --git a/tests/unit_tests/iterators/test_monte_carlo.py b/tests/unit_tests/iterators/test_monte_carlo.py index 7473deed3..a25419572 100644 --- a/tests/unit_tests/iterators/test_monte_carlo.py +++ b/tests/unit_tests/iterators/test_monte_carlo.py @@ -21,15 +21,11 @@ @pytest.fixture(name="default_mc_iterator") -def fixture_default_mc_iterator( - global_settings, default_simulation_model, default_parameters_mixed -): +def fixture_default_mc_iterator(global_settings, ishigami_90_mixed, default_parameters_mixed): """Default monte carlo iterator.""" - default_simulation_model.driver.parameters = default_parameters_mixed - # create LHS iterator my_iterator = MonteCarlo( - model=default_simulation_model, + model=ishigami_90_mixed, parameters=default_parameters_mixed, global_settings=global_settings, seed=42, diff --git a/tests/unit_tests/iterators/test_sobol_index.py b/tests/unit_tests/iterators/test_sobol_index.py index cbcff8260..cebf3bba8 100644 --- a/tests/unit_tests/iterators/test_sobol_index.py +++ b/tests/unit_tests/iterators/test_sobol_index.py @@ -22,13 +22,11 @@ @pytest.fixture(name="default_sobol_index_iterator") def fixture_default_sobol_index_iterator( - global_settings, default_simulation_model, default_parameters_uniform_3d + global_settings, ishigami_90_uniform, default_parameters_uniform_3d ): """Default sobol index iterator.""" - default_simulation_model.driver.parameters = default_parameters_uniform_3d - my_iterator = SobolIndex( - default_simulation_model, + ishigami_90_uniform, parameters=default_parameters_uniform_3d, global_settings=global_settings, seed=42, @@ -44,13 +42,11 @@ def fixture_default_sobol_index_iterator( @pytest.fixture(name="default_sobol_index_iterator_mixed") def fixture_default_sobol_index_iterator_mixed( - global_settings, default_simulation_model, default_parameters_mixed + global_settings, ishigami_90_mixed, default_parameters_mixed ): """Default sobol index iterator with different distributions.""" - default_simulation_model.driver.parameters = default_parameters_mixed - my_iterator = SobolIndex( - model=default_simulation_model, + model=ishigami_90_mixed, parameters=default_parameters_mixed, global_settings=global_settings, seed=42, diff --git a/tests/unit_tests/iterators/test_sobol_sequence.py b/tests/unit_tests/iterators/test_sobol_sequence.py index 71adf2792..caadb617c 100644 --- a/tests/unit_tests/iterators/test_sobol_sequence.py +++ b/tests/unit_tests/iterators/test_sobol_sequence.py @@ -21,13 +21,10 @@ @pytest.fixture(name="default_qmc_iterator") -def fixture_default_qmc_iterator( - global_settings, default_simulation_model, default_parameters_mixed -): +def fixture_default_qmc_iterator(global_settings, ishigami_90_mixed, default_parameters_mixed): """Sobol sequence iterator.""" - default_simulation_model.driver.parameters = default_parameters_mixed my_iterator = SobolSequence( - model=default_simulation_model, + model=ishigami_90_mixed, parameters=default_parameters_mixed, global_settings=global_settings, seed=42, diff --git a/tests/unit_tests/models/test_differentiable_adjoint.py b/tests/unit_tests/models/test_differentiable_adjoint.py index bd6476054..cddf3c6ec 100644 --- a/tests/unit_tests/models/test_differentiable_adjoint.py +++ b/tests/unit_tests/models/test_differentiable_adjoint.py @@ -53,7 +53,7 @@ def test_init(): adjoint_file=adjoint_file, ) assert model_obj.scheduler == scheduler - assert model_obj.driver == driver + assert model_obj.function == driver assert model_obj.gradient_driver == gradient_driver assert model_obj.adjoint_file == adjoint_file diff --git a/tests/unit_tests/models/test_differentiable_fd.py b/tests/unit_tests/models/test_differentiable_fd.py index cd898d4d0..0db209ba7 100644 --- a/tests/unit_tests/models/test_differentiable_fd.py +++ b/tests/unit_tests/models/test_differentiable_fd.py @@ -52,7 +52,7 @@ def test_init(): bounds=bounds, ) assert model_obj.scheduler == scheduler - assert model_obj.driver == driver + assert model_obj.function == driver assert model_obj.finite_difference_method == finite_difference_method assert model_obj.step_size == step_size np.testing.assert_equal(model_obj.bounds, np.array(bounds)) diff --git a/tests/unit_tests/models/test_simulation.py b/tests/unit_tests/models/test_simulation.py index 7dc201d24..e8202b014 100644 --- a/tests/unit_tests/models/test_simulation.py +++ b/tests/unit_tests/models/test_simulation.py @@ -28,7 +28,7 @@ def test_init(): driver = Mock() model_obj = Simulation(scheduler=scheduler, driver=driver) assert model_obj.scheduler == scheduler - assert model_obj.driver == driver + assert model_obj.function == driver @pytest.fixture(name="scheduler_response")