Skip to content

Commit 89e3d7c

Browse files
committed
make JobDoc available at runtime
1 parent 8a78800 commit 89e3d7c

File tree

5 files changed

+73
-4
lines changed

5 files changed

+73
-4
lines changed

src/jobflow_remote/jobs/run.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212

1313
from jobflow import JobStore, initialize_logger
1414
from jobflow.core.flow import get_flow
15+
from monty.design_patterns import singleton
1516
from monty.os import cd
1617
from monty.serialization import dumpfn, loadfn
1718
from monty.shutil import decompress_file
1819

1920
from jobflow_remote.jobs.batch import LocalBatchManager
20-
from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME
21+
from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME, JobDoc
2122
from jobflow_remote.remote.data import get_job_path, get_store_file_paths
2223
from jobflow_remote.utils.log import initialize_remote_run_log
2324

@@ -29,6 +30,20 @@
2930
logger = logging.getLogger(__name__)
3031

3132

33+
@singleton
class JfrState:
    """State of the job currently being executed.

    The class is wrapped by the ``singleton`` decorator, so the state set
    through one ``JfrState()`` call is meant to be visible through any other.
    """

    # JobDoc of the job being run; None when no job document has been set
    # or after reset() has been called.
    job_doc: JobDoc | None = None

    def reset(self) -> None:
        """Reset the current state by clearing the stored job document."""
        self.job_doc = None
42+
43+
44+
CURRENT_JOBDOC: JfrState = JfrState()
45+
46+
3247
def run_remote_job(run_dir: str | Path = ".") -> None:
3348
"""Run the job."""
3449
initialize_remote_run_log()
@@ -42,6 +57,10 @@ def run_remote_job(run_dir: str | Path = ".") -> None:
4257

4358
job: Job = in_data["job"]
4459
store = in_data["store"]
60+
job_doc_dict = in_data.get("job_doc", None)
61+
if job_doc_dict:
62+
job_doc_dict["job"] = job
63+
JfrState().job_doc = JobDoc.model_validate(job_doc_dict)
4564

4665
store.connect()
4766

@@ -91,6 +110,8 @@ def run_remote_job(run_dir: str | Path = ".") -> None:
91110
"end_time": datetime.datetime.utcnow(),
92111
}
93112
dumpfn(output, OUT_FILENAME)
113+
finally:
114+
JfrState().reset()
94115

95116

96117
def run_batch_jobs(

src/jobflow_remote/jobs/runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def upload(self, lock: MongoLock) -> None:
648648
logger.error(err_msg)
649649
raise RemoteError(err_msg, no_retry=False)
650650

651-
serialized_input = get_remote_in_file(job_dict, remote_store)
651+
serialized_input = get_remote_in_file(job_dict, remote_store, doc)
652652

653653
path_file = Path(remote_path, IN_FILENAME)
654654
host.put(serialized_input, str(path_file))

src/jobflow_remote/remote/data.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,18 @@ def get_local_data_path(
5656
return get_job_path(job_id, index, local_base_dir)
5757

5858

59-
def get_remote_in_file(job, remote_store):
59+
def get_remote_in_file(job, remote_store, job_doc=None):
60+
# remove the job from the job_doc, if present.
61+
# Create the copy from scratch to avoid allocating the job multiple
62+
# times if it is big
63+
job_doc_copy = None
64+
if job_doc is not None:
65+
job_doc_copy = {k: v for k, v in job_doc.items() if k not in ("job", "_id")}
66+
# the document is likely locked when getting here.
67+
job_doc_copy["lock_id"] = None
68+
job_doc_copy["lock_time"] = None
6069
d = jsanitize(
61-
{"job": job, "store": remote_store},
70+
{"job": job, "store": remote_store, "job_doc": job_doc_copy},
6271
strict=True,
6372
allow_bson=True,
6473
enum_values=True,

src/jobflow_remote/testing/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,10 @@ def ignore_input(a: int) -> int:
9696
Allows to test flows with failed parents
9797
"""
9898
return 1
99+
100+
101+
@job
def current_jobdoc():
    """Return the job document stored in the runtime state.

    Test job that reads ``CURRENT_JOBDOC.job_doc`` at execution time and
    returns it as the job output.
    """
    # Imported inside the job body, so the lookup happens when the job runs.
    from jobflow_remote.jobs.run import CURRENT_JOBDOC

    runtime_state = CURRENT_JOBDOC
    return runtime_state.job_doc

tests/db/jobs/test_run.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
def test_current_jobdoc(job_controller, runner):
    """Check that the JobDoc is exposed to a running job and then cleared."""
    from jobflow_remote import submit_flow
    from jobflow_remote.jobs.run import CURRENT_JOBDOC, JfrState
    from jobflow_remote.testing import current_jobdoc

    # Keys whose values are not compared between the stored document and
    # the document returned by the job.
    mismatching_keys = (
        "state",
        "end_time",
        "start_time",
        "updated_on",
        "remote",
        "run_dir",
        "created_on",
    )

    j = current_jobdoc()
    submit_flow([j], worker="test_local_worker")
    runner.run_one_job()

    job_output = job_controller.jobstore.get_output(uuid=j.uuid)
    job_doc = job_controller.get_job_doc(job_id=j.uuid).as_db_dict()
    for key, expected in job_doc.items():
        if key in mismatching_keys:
            continue
        assert expected == job_output[key]

    # check that CURRENT_JOBDOC is a singleton and can be set
    state = JfrState()
    assert state.job_doc is None
    assert CURRENT_JOBDOC.job_doc is None

    state.job_doc = job_doc
    assert CURRENT_JOBDOC.job_doc == job_doc

    state.reset()
    assert CURRENT_JOBDOC.job_doc is None

0 commit comments

Comments
 (0)