diff --git a/job_creator/jobcreator/job_creator.py b/job_creator/jobcreator/job_creator.py index ced77bd..33b81eb 100644 --- a/job_creator/jobcreator/job_creator.py +++ b/job_creator/jobcreator/job_creator.py @@ -9,7 +9,14 @@ from jobcreator.utils import load_kubernetes_config, logger -def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source: str, mount_options: list[str]) -> None: +def _setup_smb_pv( + pv_name: str, + secret_name: str, + secret_namespace: str, + source: str, + mount_options: list[str], + access_mode: str = "ReadOnlyMany", +) -> None: """ Sets up an smb PV using the loaded kubeconfig as a destination :param pv_name: str, The name given to the smb-pv when it's made @@ -17,7 +24,8 @@ def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source: :param secret_namespace: str, the namespace of the secret :param source: str, The IP/url/uri that is used to mount the smb share :param mount_options: list, The mount options for the smb share - :return: str, the name of the archive PV + :param access_mode: str, The access mode for the PV. Defaults to "ReadOnlyMany" + :return: str, the name of the PV """ metadata = client.V1ObjectMeta(name=pv_name, annotations={"pv.kubernetes.io/provisioned-by": "smb.csi.k8s.io"}) secret_ref = client.V1SecretReference(name=secret_name, namespace=secret_namespace) @@ -30,13 +38,13 @@ def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source: ) spec = client.V1PersistentVolumeSpec( capacity={"storage": "1000Gi"}, - access_modes=["ReadOnlyMany"], + access_modes=[access_mode], persistent_volume_reclaim_policy="Retain", mount_options=mount_options, csi=csi, ) - archive_pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec) - client.CoreV1Api().create_persistent_volume(archive_pv) + pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec) + client.CoreV1Api().create_persistent_volume(pv) def _setup_pvc(pvc_name: str, pv_name: str, namespace: str, access_mode: str = "ReadOnlyMany") -> None: @@ -163,6 +171,15 @@ def _setup_ceph_pv( return pv_name +def _setup_ngem_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None: + ngem_pv_name = f"{job_name}-ngem-pv-smb" + ngem_pvc_name = f"{job_name}-ngem-pvc" + _setup_smb_pv(ngem_pv_name, "archive-creds", namespace, "//isis.cclrc.ac.uk/Science", [], "ReadWriteMany") + _setup_pvc(ngem_pvc_name, ngem_pv_name, namespace, "ReadWriteMany") + pv_names.append(ngem_pv_name) + pvc_names.append(ngem_pvc_name) + + def _setup_imat_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None: imat_pv_name = f"{job_name}-ndximat-pv-smb" imat_pvc_name = f"{job_name}-ndximat-pvc" @@ -385,6 +402,13 @@ def spawn_job( # noqa: PLR0913 client.V1VolumeMount(name="extras-mount", mount_path="/extras"), ] # Setup special PVs and add them to the volume mounts + if "ngem" in special_pvs: + _setup_ngem_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names) + ngem_pvc_source = client.V1PersistentVolumeClaimVolumeSource( + claim_name=f"{job_name}-ngem-pvc", read_only=False + ) + volumes.append(client.V1Volume(name="ngem-mount", persistent_volume_claim=ngem_pvc_source)) + volumes_mounts.append(client.V1VolumeMount(name="ngem-mount", mount_path="/ngem")) if "imat" in special_pvs: _setup_imat_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names) imat_pvc_source = client.V1PersistentVolumeClaimVolumeSource( diff --git a/job_creator/jobcreator/main.py b/job_creator/jobcreator/main.py index caeca17..6c9675c 100644 --- a/job_creator/jobcreator/main.py +++ b/job_creator/jobcreator/main.py @@ -46,7 +46,19 @@ CONSUMER_PASSWORD = os.environ.get("QUEUE_PASSWORD", "") REDUCE_USER_ID = os.environ.get("REDUCE_USER_ID", "") JOB_NAMESPACE = os.environ.get("JOB_NAMESPACE", "fia") -JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA) +JOB_CREATOR: JobCreator | None = None + + +def get_job_creator() -> JobCreator: + """Use this Singleton pattern to allow for trivial mocking within testing infrastructure, not strictly needed but + trivialises test implementation for mocking out the JobCreator.""" + global JOB_CREATOR # noqa: PLW0603 + if JOB_CREATOR is None: + if WATCHER_SHA is None: + raise OSError("WATCHER_SHA not set in the environment, please add it.") + JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA) + return JOB_CREATOR + CEPH_CREDS_SECRET_NAME = os.environ.get("CEPH_CREDS_SECRET_NAME", "ceph-creds") CEPH_CREDS_SECRET_NAMESPACE = os.environ.get("CEPH_CREDS_SECRET_NAMESPACE", "fia") @@ -59,7 +71,7 @@ MAX_TIME_TO_COMPLETE = int(os.environ.get("MAX_TIME_TO_COMPLETE", str(60 * 60 * 6))) -def _generate_special_pvs(instrument: str) -> list[str]: +def _generate_special_pvs(instrument: str, additional_values: dict[str, Any]) -> list[str]: """ A generic function for, based on passed args, returning what the special persistent volumes should be. """ @@ -68,19 +80,31 @@ def _generate_special_pvs(instrument: str) -> list[str]: match instrument.lower(): case "imat": logger.info("Special PV for %s added.", instrument) - special_pvs.append("imat") + if "ngem" in additional_values and additional_values["ngem"] == "true": + special_pvs.append("ngem") + else: + special_pvs.append("imat") + case "ines": + logger.info("Special PV for %s added.", instrument) + if "ngem" in additional_values and additional_values["ngem"] == "true": + special_pvs.append("ngem") + else: + special_pvs.append("ines") case _: logger.info("No special PV needed for %s", instrument) return special_pvs -def _select_runner_image(instrument: str) -> str: +def _select_runner_image(instrument: str, additional_values: dict[str, Any]) -> str: """ A generic function for, based on passed args, returning what the runner that should be used. """ match instrument.lower(): case "imat": + if "ngem" in additional_values and additional_values["ngem"] == "true": + # For ngem we want to return the default mantid runner. INES always wants mantid default runner. + return DEFAULT_RUNNER if IMAGING_RUNNER_SHA is not None: logger.info("Imaging runner image selected for %s ", instrument) return IMAGING_RUNNER @@ -91,7 +115,9 @@ def _select_runner_image(instrument: str) -> str: return DEFAULT_RUNNER -def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: +def _select_taints_and_affinity( + instrument: str, additional_values: dict[str, Any] +) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: """ A generic function for, based on passed args, returning what the runner that should be used. """ @@ -100,9 +126,10 @@ def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]], match instrument.lower(): case "imat": - logger.info("Applying taint to the job on instrument %s", instrument) - taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"}) - affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]} + if "ngem" not in additional_values or additional_values["ngem"] != "true": + logger.info("Applying taint to the job on instrument %s", instrument) + taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"}) + affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]} case _: logger.info("No taints applied to %s runners", instrument) @@ -146,7 +173,7 @@ def process_simple_message(message: dict[str, Any]) -> None: {"user_number": str(user_number)} if user_number else {"experiment_number": str(experiment_number)} ) ceph_mount_path = create_ceph_mount_path_simple(**ceph_mount_path_kwargs) - JOB_CREATOR.spawn_job( + get_job_creator().spawn_job( job_name=job_name, script=script, job_namespace=JOB_NAMESPACE, @@ -185,12 +212,16 @@ def process_rerun_message(message: dict[str, Any]) -> None: rb_number=str(message["rb_number"]), ) - special_pvs = _generate_special_pvs(instrument=message["instrument"]) - taints, affinity = _select_taints_and_affinity(instrument=message["instrument"]) + special_pvs = _generate_special_pvs( + instrument=message["instrument"], additional_values=message.get("additional_values", {}) + ) + taints, affinity = _select_taints_and_affinity( + instrument=message["instrument"], additional_values=message.get("additional_values", {}) + ) # Add UUID which will avoid collisions for reruns job_name = f"run-{str(message['filename']).lower()}-{uuid.uuid4().hex!s}" - JOB_CREATOR.spawn_job( + get_job_creator().spawn_job( job_name=job_name, script=script, job_namespace=JOB_NAMESPACE, @@ -226,7 +257,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None: instrument_name = message["instrument"] runner_image = message.get("runner_image") if runner_image is None: - runner_image = _select_runner_image(instrument_name) + runner_image = _select_runner_image(instrument_name, message["additional_values"]) runner_image = find_sha256_of_image(runner_image) autoreduction_request = { "filename": filename, @@ -242,8 +273,10 @@ def process_autoreduction_message(message: dict[str, Any]) -> None: "runner_image": runner_image, } - special_pvs = _generate_special_pvs(instrument=instrument_name) - taints, affinity = _select_taints_and_affinity(instrument=message["instrument"]) + special_pvs = _generate_special_pvs(instrument=instrument_name, additional_values=message["additional_values"]) + taints, affinity = _select_taints_and_affinity( + instrument=message["instrument"], additional_values=message["additional_values"] + ) # Add UUID which will avoid collisions for reruns job_name = f"run-{filename.lower()}-{uuid.uuid4().hex!s}" @@ -253,7 +286,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None: autoreduction_request=autoreduction_request, ) ceph_mount_path = create_ceph_mount_path_autoreduction(instrument_name, rb_number) - JOB_CREATOR.spawn_job( + get_job_creator().spawn_job( job_name=job_name, script=script, job_namespace=JOB_NAMESPACE, diff --git a/job_creator/test/test_job_creator.py b/job_creator/test/test_job_creator.py index 32dc99f..68469d5 100644 --- a/job_creator/test/test_job_creator.py +++ b/job_creator/test/test_job_creator.py @@ -4,9 +4,13 @@ from jobcreator.job_creator import ( JobCreator, + _generate_affinities, + _generate_tolerations_from_taints, _setup_ceph_pv, _setup_extras_pv, _setup_extras_pvc, + _setup_imat_pv_and_pvcs, + _setup_ngem_pv_and_pvcs, _setup_pvc, _setup_smb_pv, ) @@ -155,6 +159,83 @@ def test_setup_extras_pv(client): client.V1SecretReference.assert_called_once_with(name="manila-creds", namespace=secret_namespace) +EXPECTED_TOLERATIONS_COUNT = 2 + + +@mock.patch("jobcreator.job_creator.client") +def test_generate_tolerations_from_taints(client): + taints = [ + {"key": "key1", "value": "value1", "operator": "Equal", "effect": "NoSchedule"}, + {"key": "key2", "operator": "Exists", "effect": "NoExecute"}, + ] + tolerations = _generate_tolerations_from_taints(taints) + + assert len(tolerations) == EXPECTED_TOLERATIONS_COUNT + client.V1Toleration.assert_has_calls( + [ + call(key="key1", value="value1", operator="Equal", effect="NoSchedule"), + call(key="key2", value=None, operator="Exists", effect="NoExecute"), + ] + ) + + +@mock.patch("jobcreator.job_creator.client") +def test_generate_affinities_none(client): + affinity = _generate_affinities(None) + assert affinity == client.V1Affinity.return_value + client.V1Affinity.assert_called_once_with(pod_anti_affinity=client.V1PodAntiAffinity.return_value) + + +@mock.patch("jobcreator.job_creator.logger") +@mock.patch("jobcreator.job_creator.client") +def test_generate_affinities_missing_key(client, logger): + node_affinity_dict = {"key": "some-key", "operator": "In"} # missing "values" + _generate_affinities(node_affinity_dict) + logger.error.assert_called_once() + client.V1Affinity.assert_called_once_with(pod_anti_affinity=client.V1PodAntiAffinity.return_value) + + +@mock.patch("jobcreator.job_creator.client") +def test_generate_affinities_valid(client): + node_affinity_dict = {"key": "some-key", "operator": "In", "values": ["val1"]} + _generate_affinities(node_affinity_dict) + client.V1Affinity.assert_called_once_with( + pod_anti_affinity=client.V1PodAntiAffinity.return_value, node_affinity=client.V1NodeAffinity.return_value + ) + + +@mock.patch("jobcreator.job_creator._setup_pvc") +@mock.patch("jobcreator.job_creator._setup_smb_pv") +def test_setup_ngem_pv_and_pvcs(setup_smb_pv, setup_pvc): + pv_names = [] + pvc_names = [] + _setup_ngem_pv_and_pvcs("job1", "ns1", pv_names, pvc_names) + setup_smb_pv.assert_called_once_with( + "job1-ngem-pv-smb", "archive-creds", "ns1", "//isis.cclrc.ac.uk/Science", [], "ReadWriteMany" + ) + setup_pvc.assert_called_once_with("job1-ngem-pvc", "job1-ngem-pv-smb", "ns1", "ReadWriteMany") + assert pv_names == ["job1-ngem-pv-smb"] + assert pvc_names == ["job1-ngem-pvc"] + + +@mock.patch("jobcreator.job_creator._setup_pvc") +@mock.patch("jobcreator.job_creator._setup_smb_pv") +def test_setup_imat_pv_and_pvcs(setup_smb_pv, setup_pvc): + pv_names = [] + pvc_names = [] + _setup_imat_pv_and_pvcs("job1", "ns1", pv_names, pvc_names) + setup_smb_pv.assert_called_once_with( + "job1-ndximat-pv-smb", + "imat-creds", + "ns1", + "//NDXIMAT.isis.cclrc.ac.uk/data$/", + [], + ) + setup_pvc.assert_called_once_with("job1-ndximat-pvc", "job1-ndximat-pv-smb", "ns1") + assert pv_names == ["job1-ndximat-pv-smb"] + assert pvc_names == ["job1-ndximat-pvc"] + + @mock.patch("jobcreator.job_creator.client") def test_setup_ceph_pv(client): pv_name = mock.MagicMock() @@ -219,6 +300,61 @@ def test_jobcreator_init(mock_load_kubernetes_config): mock_load_kubernetes_config.assert_called_once() +@mock.patch("jobcreator.job_creator._setup_ngem_pv_and_pvcs") +@mock.patch("jobcreator.job_creator._setup_imat_pv_and_pvcs") +@mock.patch("jobcreator.job_creator._setup_extras_pv") +@mock.patch("jobcreator.job_creator._setup_extras_pvc") +@mock.patch("jobcreator.job_creator._setup_smb_pv") +@mock.patch("jobcreator.job_creator._setup_pvc") +@mock.patch("jobcreator.job_creator._setup_ceph_pv") +@mock.patch("jobcreator.job_creator.load_kubernetes_config") +@mock.patch("jobcreator.job_creator.client") +def test_jobcreator_spawn_job_ngem( + client, + _, # noqa: PT019 + setup_ceph_pv, + setup_pvc, + setup_smb_pv, + setup_extras_pvc, + setup_extras_pv, + setup_imat_pv, + setup_ngem_pv, +): + job_name = "test-job" + script = "test-script" + job_namespace = "test-ns" + watcher_sha = "test-sha" + job_creator = JobCreator(watcher_sha, False) + + job_creator.spawn_job( + job_name=job_name, + script=script, + job_namespace=job_namespace, + ceph_creds_k8s_secret_name="some-secret-name", # noqa: S106 + ceph_creds_k8s_namespace="ns", + cluster_id="id", + fs_name="fs", + ceph_mount_path="/path", + job_id=1, + max_time_to_complete_job=100, + fia_api_host="host", + fia_api_api_key="key", + runner_image="image", + manila_share_id="mid", + manila_share_access_id="maid", + special_pvs=["ngem"], + taints=[], + affinity=None, + ) + + setup_ngem_pv.assert_called_once() + # Check that ngem volume and volume mount were added + # We check if V1Volume was called with name="ngem-mount" + assert any(c.kwargs.get("name") == "ngem-mount" for c in client.V1Volume.call_args_list) + # Check if V1VolumeMount was called with name="ngem-mount" + assert any(c.kwargs.get("name") == "ngem-mount" for c in client.V1VolumeMount.call_args_list) + + @mock.patch("jobcreator.job_creator._setup_extras_pv") @mock.patch("jobcreator.job_creator._setup_extras_pvc") @mock.patch("jobcreator.job_creator._setup_smb_pv") diff --git a/job_creator/test/test_main.py b/job_creator/test/test_main.py new file mode 100644 index 0000000..5ee3410 --- /dev/null +++ b/job_creator/test/test_main.py @@ -0,0 +1,294 @@ +import importlib +import os +import unittest +from unittest import mock + +import pytest + +# Mock environment variables and kubernetes config before importing main +with ( + mock.patch.dict( + os.environ, + {"DEFAULT_RUNNER_SHA": "default-sha", "IMAGING_RUNNER_SHA": "imaging-sha", "WATCHER_SHA": "watcher-sha"}, + ), + mock.patch("jobcreator.utils.load_kubernetes_config"), +): + from jobcreator.main import ( + _generate_special_pvs, + _select_runner_image, + _select_taints_and_affinity, + main, + process_autoreduction_message, + process_message, + process_rerun_message, + process_simple_message, + write_readiness_probe_file, + ) + + +import jobcreator.main + +EXPECTED_EXPERIMENT_JOB_ID = 2 +EXPECTED_RERUN_JOB_ID = 100 +EXPECTED_AUTOREDUCTION_JOB_ID = 200 +EXPECTED_AUTO_CALL_COUNT = 2 + + +class TestMain(unittest.TestCase): + def test_generate_special_pvs_imat_with_ngem(self): + additional_values = {"ngem": "true"} + pvs = _generate_special_pvs("imat", additional_values) + assert pvs == ["ngem"] + + def test_generate_special_pvs_imat_without_ngem(self): + additional_values = {"ngem": "false"} + pvs = _generate_special_pvs("imat", additional_values) + assert pvs == ["imat"] + + def test_generate_special_pvs_ines_with_ngem(self): + additional_values = {"ngem": "true"} + pvs = _generate_special_pvs("ines", additional_values) + assert pvs == ["ngem"] + + def test_generate_special_pvs_ines_without_ngem(self): + additional_values = {"ngem": "false"} + pvs = _generate_special_pvs("ines", additional_values) + assert pvs == ["ines"] + + def test_generate_special_pvs_other(self): + additional_values = {"ngem": "true"} + pvs = _generate_special_pvs("other", additional_values) + assert pvs == [] + + def test_select_runner_image_imat_with_ngem(self): + additional_values = {"ngem": "true"} + image = _select_runner_image("imat", additional_values) + assert "default-sha" in image + + def test_select_runner_image_imat_without_ngem(self): + with ( + mock.patch("jobcreator.main.IMAGING_RUNNER_SHA", "imaging-sha"), + mock.patch("jobcreator.main.IMAGING_RUNNER", "ghcr.io/fiaisis/mantidimaging@sha256:imaging-sha"), + ): + additional_values = {"ngem": "false"} + image = _select_runner_image("imat", additional_values) + assert "imaging-sha" in image + + def test_select_runner_image_other(self): + additional_values = {"ngem": "true"} + image = _select_runner_image("other", additional_values) + assert "default-sha" in image + + def test_select_taints_and_affinity_imat_with_ngem(self): + additional_values = {"ngem": "true"} + taints, affinity = _select_taints_and_affinity("imat", additional_values) + assert taints == [] + assert affinity is None + + def test_select_taints_and_affinity_imat_without_ngem(self): + additional_values = {"ngem": "false"} + taints, affinity = _select_taints_and_affinity("imat", additional_values) + assert len(taints) == 1 + assert taints[0]["key"] == "nvidia.com/gpu" + assert affinity is not None + assert affinity["key"] == "node-type" + + def test_select_taints_and_affinity_other(self): + additional_values = {"ngem": "true"} + taints, affinity = _select_taints_and_affinity("other", additional_values) + assert taints == [] + assert affinity is None + + @mock.patch("jobcreator.main.get_job_creator") + @mock.patch("jobcreator.main.create_ceph_mount_path_simple") + @mock.patch("jobcreator.main.find_sha256_of_image") + def test_process_simple_message_user_number(self, mock_find_sha, mock_create_path, mock_get_job_creator): + mock_job_creator = mock_get_job_creator.return_value + mock_find_sha.return_value = "sha256:123" + mock_create_path.return_value = "/ceph/path" + message = { + "runner_image": "image:latest", + "script": "print('hello')", + "user_number": 12345, + "job_id": 1, + "taints": "[]", + "affinity": "{}", + } + process_simple_message(message) + mock_job_creator.spawn_job.assert_called_once() + kwargs = mock_job_creator.spawn_job.call_args.kwargs + assert "run-owner12345-requested-" in kwargs["job_name"] + assert kwargs["script"] == "print('hello')" + assert kwargs["runner_image"] == "sha256:123" + assert kwargs["job_id"] == 1 + + @mock.patch("jobcreator.main.get_job_creator") + @mock.patch("jobcreator.main.create_ceph_mount_path_simple") + @mock.patch("jobcreator.main.find_sha256_of_image") + def test_process_simple_message_experiment_number(self, mock_find_sha, mock_create_path, mock_get_job_creator): + mock_job_creator = mock_get_job_creator.return_value + mock_find_sha.return_value = "sha256:123" + mock_create_path.return_value = "/ceph/path" + message = { + "runner_image": "image:latest", + "script": "print('hello')", + "experiment_number": 67890, + "job_id": 2, + } + process_simple_message(message) + mock_job_creator.spawn_job.assert_called_once() + kwargs = mock_job_creator.spawn_job.call_args.kwargs + assert "run-owner67890-requested-" in kwargs["job_name"] + assert kwargs["job_id"] == EXPECTED_EXPERIMENT_JOB_ID + + @mock.patch("jobcreator.main.logger") + def test_process_simple_message_invalid_job_id(self, mock_logger): + message = {"runner_image": "image:latest", "script": "print('hello')", "job_id": "not-an-int"} + process_simple_message(message) + mock_logger.exception.assert_called_once() + + @mock.patch("jobcreator.main.get_job_creator") + @mock.patch("jobcreator.main.create_ceph_mount_path_autoreduction") + @mock.patch("jobcreator.main.find_sha256_of_image") + def test_process_rerun_message(self, mock_find_sha, mock_create_path, mock_get_job_creator): + mock_job_creator = mock_get_job_creator.return_value + mock_find_sha.return_value = "sha256:rerun" + mock_create_path.return_value = "/ceph/autoreduction" + message = { + "runner_image": "image:latest", + "script": "rerun script", + "instrument": "imat", + "rb_number": "12345", + "filename": "data.nxs", + "job_id": 100, + } + process_rerun_message(message) + mock_job_creator.spawn_job.assert_called_once() + kwargs = mock_job_creator.spawn_job.call_args.kwargs + assert "run-data.nxs-" in kwargs["job_name"] + assert kwargs["job_id"] == EXPECTED_RERUN_JOB_ID + assert kwargs["special_pvs"] == ["imat"] + + @mock.patch("jobcreator.main.get_job_creator") + @mock.patch("jobcreator.main.post_autoreduction_job") + @mock.patch("jobcreator.main.create_ceph_mount_path_autoreduction") + @mock.patch("jobcreator.main.find_sha256_of_image") + def test_process_autoreduction_message(self, mock_find_sha, mock_create_path, mock_post_job, mock_get_job_creator): + mock_job_creator = mock_get_job_creator.return_value + mock_find_sha.return_value = "sha256:auto" + mock_create_path.return_value = "/ceph/auto" + mock_post_job.return_value = ("generated_script", 200) + message = { + "filepath": "/path/to/data.nxs", + "experiment_number": "67890", + "instrument": "imat", + "experiment_title": "test title", + "users": "user1", + "run_start": "start", + "run_end": "end", + "good_frames": 100, + "raw_frames": 110, + "additional_values": {"ngem": "false"}, + } + process_autoreduction_message(message) + mock_job_creator.spawn_job.assert_called_once() + kwargs = mock_job_creator.spawn_job.call_args.kwargs + assert "run-data-" in kwargs["job_name"] + assert kwargs["job_id"] == EXPECTED_AUTOREDUCTION_JOB_ID + assert kwargs["script"] == "generated_script" + + @mock.patch("jobcreator.main.process_simple_message") + @mock.patch("jobcreator.main.process_rerun_message") + @mock.patch("jobcreator.main.process_autoreduction_message") + def test_process_message(self, mock_auto, mock_rerun, mock_simple): + process_message({"job_type": "simple"}) + mock_simple.assert_called_once() + process_message({"job_type": "rerun"}) + mock_rerun.assert_called_once() + process_message({"job_type": "autoreduction"}) + mock_auto.assert_called_once() + process_message({}) # defaults to autoreduction + assert mock_auto.call_count == EXPECTED_AUTO_CALL_COUNT + + @mock.patch("jobcreator.main.time") + @mock.patch("jobcreator.main.Path") + def test_write_readiness_probe_file(self, mock_path, mock_time): + mock_file = mock.MagicMock() + mock_path.return_value.open.return_value.__enter__.return_value = mock_file + mock_time.strftime.return_value = "2023-01-01 00:00:00" + + write_readiness_probe_file() + + mock_file.write.assert_called_once_with("2023-01-01 00:00:00") + + @mock.patch("jobcreator.main.QueueConsumer") + def test_main(self, mock_consumer): + main() + mock_consumer.assert_called_once() + mock_consumer.return_value.start_consuming.assert_called_once() + + def test_select_runner_image_imat_missing_sha(self): + with mock.patch("jobcreator.main.IMAGING_RUNNER_SHA", None): + image = _select_runner_image("imat", {}) + assert "default-sha" in image + + @mock.patch("jobcreator.main.logger") + def test_process_rerun_message_exception(self, mock_logger): + # This should trigger a KeyError since 'runner_image' is missing + process_rerun_message({}) + mock_logger.exception.assert_called_once() + + @mock.patch("jobcreator.main.logger") + def test_process_autoreduction_message_exception(self, mock_logger): + # This should trigger a KeyError since 'filepath' is missing + process_autoreduction_message({}) + mock_logger.exception.assert_called_once() + + @mock.patch("jobcreator.main.logger") + def test_process_message_unrecognised(self, mock_logger): + process_message({"job_type": "unknown"}) + mock_logger.warn.assert_called_once() + + @mock.patch("jobcreator.main.main") + def test_main_coverage(self, mock_main): + # Test missing DEFAULT_RUNNER_SHA + with ( + mock.patch.dict(os.environ, {"WATCHER_SHA": "watcher-sha"}, clear=True), + mock.patch("jobcreator.utils.load_kubernetes_config"), + pytest.raises(OSError, match="DEFAULT_RUNNER_SHA"), + ): + importlib.reload(jobcreator.main) + + # Test missing WATCHER_SHA + with ( + mock.patch.dict(os.environ, {"DEFAULT_RUNNER_SHA": "default-sha"}, clear=True), + mock.patch("jobcreator.utils.load_kubernetes_config"), + pytest.raises(OSError, match="WATCHER_SHA"), + ): + importlib.reload(jobcreator.main) + + # Test DEV_MODE branch + with ( + mock.patch.dict( + os.environ, {"DEFAULT_RUNNER_SHA": "default-sha", "WATCHER_SHA": "watcher-sha", "DEV_MODE": "True"} + ), + mock.patch("jobcreator.utils.load_kubernetes_config"), + ): + importlib.reload(jobcreator.main) + assert jobcreator.main.DEV_MODE + + # Test __name__ == "__main__" block + with ( + mock.patch.dict(os.environ, {"DEFAULT_RUNNER_SHA": "default-sha", "WATCHER_SHA": "watcher-sha"}), + mock.patch("jobcreator.utils.load_kubernetes_config"), + ): + # Use reload to trigger imports and most lines + with mock.patch("jobcreator.main.__name__", "__main__"), mock.patch("jobcreator.main.main"): + importlib.reload(jobcreator.main) + + # Explicitly execute the if __name__ == "__main__": line and the main() call + # We already have mock_main from the decorator + source = "if __name__ == '__main__': main()" + exec_globals = {"main": mock_main, "__name__": "__main__"} + exec(source, exec_globals) # noqa: S102 + mock_main.assert_called_once()