Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def job(method: Callable = None, **job_kwargs) -> Callable[..., Callable[..., Jo
pass


def job(method: Callable = None, **job_kwargs):
def job(
method: Callable = None, **job_kwargs
) -> Callable[..., Job] | Callable[..., Callable[..., Job]]:
"""
Wrap a function to produce a :obj:`Job`.

Expand Down Expand Up @@ -638,9 +640,7 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response:
pass_manager_config(response.replace, passed_config)

try:
output = jsanitize(
response.output, strict=True, enum_values=True, allow_bson=True
)
output = jsanitize(response.output, strict=True, allow_bson=True)
except AttributeError as err:
raise RuntimeError(
"Job output contained an object that is not MSONable and therefore "
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def as_dict(self) -> dict:

# fireworks can't serialize functions and classes, so explicitly serialize to
# the job recursively using monty to avoid issues
return jsanitize(d, strict=True, enum_values=True, allow_bson=True)
return jsanitize(d, strict=True, allow_bson=True)

def __setattr__(self, key, value):
"""Handle setting attributes. Implements a special case for job name."""
Expand Down Expand Up @@ -1288,7 +1288,6 @@ def apply_schema(output: Any, schema: type[BaseModel] | None):
return schema(**output)


@job(config=JobConfig(resolve_references=False, on_missing_references=OnMissing.NONE))
def store_inputs(inputs: Any) -> Any:
"""
Job to store inputs.
Expand Down Expand Up @@ -1338,7 +1337,12 @@ def prepare_replace(
# add a job with same UUID as the current job to store the outputs of the
# flow; this job will inherit the metadata and output schema of the current
# job
store_output_job = store_inputs(replace.output)
new_config = JobConfig(
resolve_references=False, on_missing_references=OnMissing.NONE
)
store_output_job = Job(
store_inputs, function_args=(replace.output,), config=new_config
)
store_output_job.set_uuid(current_job.uuid)
store_output_job.index = current_job.index + 1
store_output_job.metadata = current_job.metadata
Expand Down Expand Up @@ -1370,17 +1374,15 @@ def pass_manager_config(
"""
Pass the manager config on to any jobs in the jobs array.

Merge with already specified manager config.

Parameters
----------
jobs
A job, flow, or list of jobs/flows.
manager_config
A manager config to pass on.
metadata
Metadata to pass on.
"""
from copy import deepcopy

all_jobs: list[Job] = []

def get_jobs(arg):
Expand All @@ -1400,4 +1402,4 @@ def get_jobs(arg):

# update manager config
for ajob in all_jobs:
ajob.config.manager_config = deepcopy(manager_config)
ajob.config.manager_config = manager_config | ajob.config.manager_config
4 changes: 2 additions & 2 deletions src/jobflow/core/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]:
# argument is a primitive, we won't find a reference here
return ()

arg = jsanitize(arg, strict=True, enum_values=True, allow_bson=True)
arg = jsanitize(arg, strict=True, allow_bson=True)

# recursively find any reference classes
locations = find_key_value(arg, "@class", "OutputReference")
Expand Down Expand Up @@ -458,7 +458,7 @@ def find_and_resolve_references(
return arg

# serialize the argument to a dictionary
encoded_arg = jsanitize(arg, strict=True, enum_values=True, allow_bson=True)
encoded_arg = jsanitize(arg, strict=True, allow_bson=True)

# recursively find any reference classes
locations = find_key_value(encoded_arg, "@class", "OutputReference")
Expand Down
13 changes: 10 additions & 3 deletions tests/core/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,15 +940,22 @@ def add_schema_replace(a, b):


def test_store_inputs(memory_jobstore):
from jobflow.core.job import Job, OutputReference, store_inputs
from jobflow.core.job import (
Job,
OutputReference,
store_inputs,
JobConfig,
OnMissing,
)

test_job = store_inputs(1)
configs = JobConfig(resolve_references=False, on_missing_references=OnMissing.NONE)
test_job = Job(store_inputs, (1,), config=configs)
test_job.run(memory_jobstore)
output = memory_jobstore.query_one({"uuid": test_job.uuid}, ["output"])["output"]
assert output == 1

ref = OutputReference("abcd")
test_job = store_inputs(ref)
test_job = Job(store_inputs, (ref,), config=configs)
test_job.run(memory_jobstore)
output = memory_jobstore.query_one({"uuid": test_job.uuid}, ["output"])["output"]
assert OutputReference.from_dict(output) == ref
Expand Down