Skip to content

Commit eda2a65

Browse files
authored
Merge pull request #544 from gpetretto/devel
jsanitize fireworks Task
2 parents c211030 + 0debb38 commit eda2a65

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

src/jobflow/managers/fireworks.py

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import typing
66

77
from fireworks import FiretaskBase, Firework, FWAction, Workflow, explicit_serialize
8+
from fireworks.utilities.fw_serializers import recursive_serialize, serialize_fw
9+
from monty.json import jsanitize
810

911
if typing.TYPE_CHECKING:
1012
from collections.abc import Sequence
@@ -197,3 +199,16 @@ def run_task(self, fw_spec):
197199
defuse_workflow=response.stop_jobflow,
198200
defuse_children=response.stop_children,
199201
)
202+
203+
@serialize_fw
204+
@recursive_serialize
205+
def to_dict(self) -> dict:
206+
"""
207+
Serialize version of the FireTask.
208+
209+
Overrides the original method to explicitly jsanitize the Job
210+
to handle cases not properly handled by fireworks, like a Callable.
211+
"""
212+
d = dict(self)
213+
d["job"] = jsanitize(d["job"].as_dict())
214+
return d

tests/managers/conftest.py

+22
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,25 @@ def _gen():
399399
return Flow([replace, simple], simple.output, order=JobOrder.LINEAR)
400400

401401
return _gen
402+
403+
404+
@pytest.fixture(scope="session")
405+
def maker_with_callable():
406+
from dataclasses import dataclass
407+
from typing import Callable
408+
409+
from jobflow.core.job import job
410+
from jobflow.core.maker import Maker
411+
412+
global TestCallableMaker
413+
414+
@dataclass
415+
class TestCallableMaker(Maker):
416+
f: Callable
417+
name: str = "TestCallableMaker"
418+
419+
@job
420+
def make(self, a, b):
421+
return self.f([a, b])
422+
423+
return TestCallableMaker

tests/managers/test_fireworks.py

+33
Original file line numberDiff line numberDiff line change
@@ -659,3 +659,36 @@ def test_external_reference(lpad, mongo_jobstore, fw_dir, simple_job, capsys):
659659
# check response
660660
result2 = mongo_jobstore.query_one({"uuid": uuid2})
661661
assert result2["output"] == "12345_end_end"
662+
663+
664+
def test_maker_flow(lpad, mongo_jobstore, fw_dir, maker_with_callable, capsys):
665+
from fireworks.core.rocket_launcher import rapidfire
666+
667+
from jobflow.core.flow import Flow
668+
from jobflow.managers.fireworks import flow_to_workflow
669+
670+
j = maker_with_callable(f=sum).make(a=1, b=2)
671+
672+
flow = Flow([j])
673+
uuid = flow[0].uuid
674+
675+
wf = flow_to_workflow(flow, mongo_jobstore)
676+
fw_ids = lpad.add_wf(wf)
677+
678+
# run the workflow
679+
rapidfire(lpad)
680+
681+
# check workflow completed
682+
fw_id = next(iter(fw_ids.values()))
683+
wf = lpad.get_wf_by_fw_id(fw_id)
684+
685+
assert all(s == "COMPLETED" for s in wf.fw_states.values())
686+
687+
# check store has the activity output
688+
result = mongo_jobstore.query_one({"uuid": uuid})
689+
assert result["output"] == 3
690+
691+
# check logs printed
692+
captured = capsys.readouterr()
693+
assert "INFO Starting job - TestCallableMaker" in captured.out
694+
assert "INFO Finished job - TestCallableMaker" in captured.out

0 commit comments

Comments
 (0)