Skip to content

Commit eb321c4

Browse files
authored
Add a test that runs with an exec_config and checks an environment variable is set in the job (#56)
1 parent 6d68367 commit eb321c4

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/jobflow_remote/testing/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,10 @@ def arithmetic(
3434
return op(a, b)
3535

3636
return None
37+
38+
39+
@job
40+
def check_env_var() -> str:
41+
import os
42+
43+
return os.environ.get("TESTING_ENV_VAR", "unset")

tests/integration/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def write_tmp_settings(
229229
connect_kwargs={"allow_agent": False, "look_for_keys": False},
230230
),
231231
},
232+
exec_config={"test": {"export": {"TESTING_ENV_VAR": random_project_name}}},
232233
runner=dict(
233234
delay_checkout=1,
234235
delay_check_run_status=1,

tests/integration/test_slurm.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,32 @@ def test_expected_failure(worker, job_controller):
205205

206206
assert job_controller.count_jobs(state=JobState.FAILED) == 2
207207
assert job_controller.count_flows(state=FlowState.FAILED) == 1
208+
209+
210+
@pytest.mark.parametrize(
211+
"worker",
212+
["test_local_worker", "test_remote_worker"],
213+
)
214+
def test_exec_config(worker, job_controller, random_project_name):
215+
"""Tests that an environment variable set in the exec config
216+
is available to the job.
217+
218+
"""
219+
220+
from jobflow_remote import submit_flow
221+
from jobflow_remote.jobs.runner import Runner
222+
from jobflow_remote.testing import check_env_var
223+
224+
job = check_env_var()
225+
submit_flow(job, worker=worker, exec_config="test")
226+
227+
assert job_controller.count_jobs({}) == 1
228+
assert len(job_controller.get_jobs({})) == 1
229+
assert job_controller.count_flows({}) == 1
230+
231+
runner = Runner()
232+
runner.run(ticks=5)
233+
234+
job = job_controller.get_jobs({})[0]
235+
output = job_controller.jobstore.get_output(uuid=job["uuid"])
236+
assert output == random_project_name

0 commit comments

Comments
 (0)