diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index a5d3315c..93ab107b 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -78,7 +78,9 @@ def job(method: Callable = None, **job_kwargs) -> Callable[..., Callable[..., Jo pass -def job(method: Callable = None, **job_kwargs): +def job( + method: Callable = None, **job_kwargs +) -> Callable[..., Job] | Callable[..., Callable[..., Job]]: """ Wrap a function to produce a :obj:`Job`. @@ -638,9 +640,7 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: pass_manager_config(response.replace, passed_config) try: - output = jsanitize( - response.output, strict=True, enum_values=True, allow_bson=True - ) + output = jsanitize(response.output, strict=True, allow_bson=True) except AttributeError as err: raise RuntimeError( "Job output contained an object that is not MSONable and therefore " @@ -1121,7 +1121,7 @@ def as_dict(self) -> dict: # fireworks can't serialize functions and classes, so explicitly serialize to # the job recursively using monty to avoid issues - return jsanitize(d, strict=True, enum_values=True, allow_bson=True) + return jsanitize(d, strict=True, allow_bson=True) def __setattr__(self, key, value): """Handle setting attributes. Implements a special case for job name.""" @@ -1288,7 +1288,6 @@ def apply_schema(output: Any, schema: type[BaseModel] | None): return schema(**output) -@job(config=JobConfig(resolve_references=False, on_missing_references=OnMissing.NONE)) def store_inputs(inputs: Any) -> Any: """ Job to store inputs. @@ -1338,7 +1337,12 @@ def prepare_replace( # add a job with same UUID as the current job to store the outputs of the # flow; this job will inherit the metadata and output schema of the current # job - store_output_job = store_inputs(replace.output) + new_config = JobConfig( + resolve_references=False, on_missing_references=OnMissing.NONE + ) + store_output_job = Job( + store_inputs, function_args=(replace.output,), config=new_config + ) store_output_job.set_uuid(current_job.uuid) store_output_job.index = current_job.index + 1 store_output_job.metadata = current_job.metadata @@ -1370,17 +1374,15 @@ def pass_manager_config( """ Pass the manager config on to any jobs in the jobs array. + Merge with already specified manager config. + Parameters ---------- jobs A job, flow, or list of jobs/flows. manager_config A manager config to pass on. - metadata - Metadata to pass on. """ - from copy import deepcopy - all_jobs: list[Job] = [] def get_jobs(arg): @@ -1400,4 +1402,4 @@ def get_jobs(arg): # update manager config for ajob in all_jobs: - ajob.config.manager_config = deepcopy(manager_config) + ajob.config.manager_config = manager_config | ajob.config.manager_config diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py index 10d774b3..07e1b1c9 100644 --- a/src/jobflow/core/reference.py +++ b/src/jobflow/core/reference.py @@ -394,7 +394,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]: # argument is a primitive, we won't find a reference here return () - arg = jsanitize(arg, strict=True, enum_values=True, allow_bson=True) + arg = jsanitize(arg, strict=True, allow_bson=True) # recursively find any reference classes locations = find_key_value(arg, "@class", "OutputReference") @@ -458,7 +458,7 @@ def find_and_resolve_references( return arg # serialize the argument to a dictionary - encoded_arg = jsanitize(arg, strict=True, enum_values=True, allow_bson=True) + encoded_arg = jsanitize(arg, strict=True, allow_bson=True) # recursively find any reference classes locations = find_key_value(encoded_arg, "@class", "OutputReference") diff --git a/tests/core/test_job.py b/tests/core/test_job.py index 90eadea8..bc47b0d3 100644 --- a/tests/core/test_job.py +++ b/tests/core/test_job.py @@ -940,15 +940,22 @@ def add_schema_replace(a, b): def test_store_inputs(memory_jobstore): - from jobflow.core.job import Job, OutputReference, store_inputs + from jobflow.core.job import ( + Job, + OutputReference, + store_inputs, + JobConfig, + OnMissing, + ) - test_job = store_inputs(1) + configs = JobConfig(resolve_references=False, on_missing_references=OnMissing.NONE) + test_job = Job(store_inputs, (1,), config=configs) test_job.run(memory_jobstore) output = memory_jobstore.query_one({"uuid": test_job.uuid}, ["output"])["output"] assert output == 1 ref = OutputReference("abcd") - test_job = store_inputs(ref) + test_job = Job(store_inputs, (ref,), config=configs) test_job.run(memory_jobstore) output = memory_jobstore.query_one({"uuid": test_job.uuid}, ["output"])["output"] assert OutputReference.from_dict(output) == ref