Skip to content
34 changes: 29 additions & 5 deletions job_creator/jobcreator/job_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@
from jobcreator.utils import load_kubernetes_config, logger


def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source: str, mount_options: list[str]) -> None:
def _setup_smb_pv(
pv_name: str,
secret_name: str,
secret_namespace: str,
source: str,
mount_options: list[str],
access_mode: str = "ReadOnlyMany",
) -> None:
"""
Sets up an smb PV using the loaded kubeconfig as a destination
:param pv_name: str, The name given to the smb-pv when it's made
:param secret_name: str, The name of the secret that contains the credentials for the smb share
:param secret_namespace: str, the namespace of the secret
:param source: str, The IP/url/uri that is used to mount the smb share
:param mount_options: list, The mount options for the smb share
:return: str, the name of the archive PV
:param access_mode: str, The access mode for the PV. Defaults to "ReadOnlyMany"
:return: str, the name of the PV
"""
metadata = client.V1ObjectMeta(name=pv_name, annotations={"pv.kubernetes.io/provisioned-by": "smb.csi.k8s.io"})
secret_ref = client.V1SecretReference(name=secret_name, namespace=secret_namespace)
Expand All @@ -30,13 +38,13 @@ def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source:
)
spec = client.V1PersistentVolumeSpec(
capacity={"storage": "1000Gi"},
access_modes=["ReadOnlyMany"],
access_modes=[access_mode],
persistent_volume_reclaim_policy="Retain",
mount_options=mount_options,
csi=csi,
)
archive_pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec)
client.CoreV1Api().create_persistent_volume(archive_pv)
pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec)
client.CoreV1Api().create_persistent_volume(pv)


def _setup_pvc(pvc_name: str, pv_name: str, namespace: str, access_mode: str = "ReadOnlyMany") -> None:
Expand Down Expand Up @@ -163,6 +171,15 @@ def _setup_ceph_pv(
return pv_name


def _setup_ngem_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None:
    """
    Sets up the ngem smb PV and its PVC, then records both names in the passed lists
    :param job_name: str, The job name, used as the prefix of the PV and PVC names
    :param namespace: str, Used as the secret namespace for the PV and as the namespace of the PVC
    :param pv_names: list, Appended to in place with the name of the created PV
    :param pvc_names: list, Appended to in place with the name of the created PVC
    :return: None
    """
    # Both the PV and the PVC are requested with the same access mode.
    access_mode = "ReadWriteMany"
    pv_name = f"{job_name}-ngem-pv-smb"
    pvc_name = f"{job_name}-ngem-pvc"

    # The PV is created before the PVC that refers to it by name.
    _setup_smb_pv(pv_name, "archive-creds", namespace, "//isis.cclrc.ac.uk/Science", [], access_mode)
    _setup_pvc(pvc_name, pv_name, namespace, access_mode)

    # Names are only recorded once both creation calls have returned.
    pv_names.append(pv_name)
    pvc_names.append(pvc_name)


def _setup_imat_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None:
imat_pv_name = f"{job_name}-ndximat-pv-smb"
imat_pvc_name = f"{job_name}-ndximat-pvc"
Expand Down Expand Up @@ -385,6 +402,13 @@ def spawn_job( # noqa: PLR0913
client.V1VolumeMount(name="extras-mount", mount_path="/extras"),
]
# Setup special PVs and add them to the volume mounts
if "ngem" in special_pvs:
_setup_ngem_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names)
ngem_pvc_source = client.V1PersistentVolumeClaimVolumeSource(
claim_name=f"{job_name}-ngem-pvc", read_only=False
)
volumes.append(client.V1Volume(name="ngem-mount", persistent_volume_claim=ngem_pvc_source))
volumes_mounts.append(client.V1VolumeMount(name="ngem-mount", mount_path="/ngem"))
if "imat" in special_pvs:
_setup_imat_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names)
imat_pvc_source = client.V1PersistentVolumeClaimVolumeSource(
Expand Down
65 changes: 49 additions & 16 deletions job_creator/jobcreator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,19 @@
CONSUMER_PASSWORD = os.environ.get("QUEUE_PASSWORD", "")
REDUCE_USER_ID = os.environ.get("REDUCE_USER_ID", "")
JOB_NAMESPACE = os.environ.get("JOB_NAMESPACE", "fia")
JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA)
JOB_CREATOR: JobCreator | None = None


def get_job_creator() -> JobCreator:
    """
    Return the module-level JobCreator, building it the first time it is requested.

    Creating it lazily (rather than at import time) is not strictly needed, but it makes it trivial for the
    test infrastructure to mock out the JobCreator.
    :return: JobCreator, the shared instance
    :raises OSError: when no instance exists yet and WATCHER_SHA is None
    """
    global JOB_CREATOR  # noqa: PLW0603
    # Fast path: hand back the instance made by an earlier call.
    if JOB_CREATOR is not None:
        return JOB_CREATOR
    if WATCHER_SHA is None:
        raise OSError("WATCHER_SHA not set in the environment, please add it.")
    JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA)
    return JOB_CREATOR


CEPH_CREDS_SECRET_NAME = os.environ.get("CEPH_CREDS_SECRET_NAME", "ceph-creds")
CEPH_CREDS_SECRET_NAMESPACE = os.environ.get("CEPH_CREDS_SECRET_NAMESPACE", "fia")
Expand All @@ -59,7 +71,7 @@
MAX_TIME_TO_COMPLETE = int(os.environ.get("MAX_TIME_TO_COMPLETE", str(60 * 60 * 6)))


def _generate_special_pvs(instrument: str) -> list[str]:
def _generate_special_pvs(instrument: str, additional_values: dict[str, Any]) -> list[str]:
"""
A generic function for, based on passed args, returning what the special persistent volumes should be.
"""
Expand All @@ -68,19 +80,31 @@ def _generate_special_pvs(instrument: str) -> list[str]:
match instrument.lower():
case "imat":
logger.info("Special PV for %s added.", instrument)
special_pvs.append("imat")
if "ngem" in additional_values and additional_values["ngem"] == "true":
special_pvs.append("ngem")
else:
special_pvs.append("imat")
case "ines":
logger.info("Special PV for %s added.", instrument)
if "ngem" in additional_values and additional_values["ngem"] == "true":
special_pvs.append("ngem")
else:
special_pvs.append("ines")
case _:
logger.info("No special PV needed for %s", instrument)

return special_pvs


def _select_runner_image(instrument: str) -> str:
def _select_runner_image(instrument: str, additional_values: dict[str, Any]) -> str:
"""
A generic function for, based on passed args, returning what the runner that should be used.
"""
match instrument.lower():
case "imat":
if "ngem" in additional_values and additional_values["ngem"] == "true":
# For ngem we want to return the default mantid runner. INES always wants mantid default runner.
return DEFAULT_RUNNER
if IMAGING_RUNNER_SHA is not None:
logger.info("Imaging runner image selected for %s ", instrument)
return IMAGING_RUNNER
Expand All @@ -91,7 +115,9 @@ def _select_runner_image(instrument: str) -> str:
return DEFAULT_RUNNER


def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
def _select_taints_and_affinity(
instrument: str, additional_values: dict[str, Any]
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
"""
A generic function for, based on passed args, returning what the runner that should be used.
"""
Expand All @@ -100,9 +126,10 @@ def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]],

match instrument.lower():
case "imat":
logger.info("Applying taint to the job on instrument %s", instrument)
taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"})
affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]}
if "ngem" not in additional_values or additional_values["ngem"] != "true":
logger.info("Applying taint to the job on instrument %s", instrument)
taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"})
affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]}
case _:
logger.info("No taints applied to %s runners", instrument)

Expand Down Expand Up @@ -146,7 +173,7 @@ def process_simple_message(message: dict[str, Any]) -> None:
{"user_number": str(user_number)} if user_number else {"experiment_number": str(experiment_number)}
)
ceph_mount_path = create_ceph_mount_path_simple(**ceph_mount_path_kwargs)
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down Expand Up @@ -185,12 +212,16 @@ def process_rerun_message(message: dict[str, Any]) -> None:
rb_number=str(message["rb_number"]),
)

special_pvs = _generate_special_pvs(instrument=message["instrument"])
taints, affinity = _select_taints_and_affinity(instrument=message["instrument"])
special_pvs = _generate_special_pvs(
instrument=message["instrument"], additional_values=message.get("additional_values", {})
)
taints, affinity = _select_taints_and_affinity(
instrument=message["instrument"], additional_values=message.get("additional_values", {})
)

# Add UUID which will avoid collisions for reruns
job_name = f"run-{str(message['filename']).lower()}-{uuid.uuid4().hex!s}"
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down Expand Up @@ -226,7 +257,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
instrument_name = message["instrument"]
runner_image = message.get("runner_image")
if runner_image is None:
runner_image = _select_runner_image(instrument_name)
runner_image = _select_runner_image(instrument_name, message["additional_values"])
runner_image = find_sha256_of_image(runner_image)
autoreduction_request = {
"filename": filename,
Expand All @@ -242,8 +273,10 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
"runner_image": runner_image,
}

special_pvs = _generate_special_pvs(instrument=instrument_name)
taints, affinity = _select_taints_and_affinity(instrument=message["instrument"])
special_pvs = _generate_special_pvs(instrument=instrument_name, additional_values=message["additional_values"])
taints, affinity = _select_taints_and_affinity(
instrument=message["instrument"], additional_values=message["additional_values"]
)

# Add UUID which will avoid collisions for reruns
job_name = f"run-{filename.lower()}-{uuid.uuid4().hex!s}"
Expand All @@ -253,7 +286,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
autoreduction_request=autoreduction_request,
)
ceph_mount_path = create_ceph_mount_path_autoreduction(instrument_name, rb_number)
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down
136 changes: 136 additions & 0 deletions job_creator/test/test_job_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

from jobcreator.job_creator import (
JobCreator,
_generate_affinities,
_generate_tolerations_from_taints,
_setup_ceph_pv,
_setup_extras_pv,
_setup_extras_pvc,
_setup_imat_pv_and_pvcs,
_setup_ngem_pv_and_pvcs,
_setup_pvc,
_setup_smb_pv,
)
Expand Down Expand Up @@ -155,6 +159,83 @@ def test_setup_extras_pv(client):
client.V1SecretReference.assert_called_once_with(name="manila-creds", namespace=secret_namespace)


EXPECTED_TOLERATIONS_COUNT = 2


@mock.patch("jobcreator.job_creator.client")
def test_generate_tolerations_from_taints(client):
    """Each taint should yield one V1Toleration call; a taint with no "value" is expected as value=None."""
    taint_with_value = {"key": "key1", "value": "value1", "operator": "Equal", "effect": "NoSchedule"}
    taint_without_value = {"key": "key2", "operator": "Exists", "effect": "NoExecute"}

    result = _generate_tolerations_from_taints([taint_with_value, taint_without_value])

    assert len(result) == EXPECTED_TOLERATIONS_COUNT
    expected_calls = [
        call(key="key1", value="value1", operator="Equal", effect="NoSchedule"),
        call(key="key2", value=None, operator="Exists", effect="NoExecute"),
    ]
    client.V1Toleration.assert_has_calls(expected_calls)


@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_none(client):
    """With no node affinity supplied, the V1Affinity is built from the pod anti-affinity alone and returned."""
    result = _generate_affinities(None)

    expected_anti_affinity = client.V1PodAntiAffinity.return_value
    assert result == client.V1Affinity.return_value
    client.V1Affinity.assert_called_once_with(pod_anti_affinity=expected_anti_affinity)


@mock.patch("jobcreator.job_creator.logger")
@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_missing_key(client, logger):
    """An affinity dict lacking "values" should log one error and build V1Affinity without a node affinity."""
    incomplete_affinity = {"key": "some-key", "operator": "In"}  # missing "values"

    _generate_affinities(incomplete_affinity)

    logger.error.assert_called_once()
    expected_anti_affinity = client.V1PodAntiAffinity.return_value
    client.V1Affinity.assert_called_once_with(pod_anti_affinity=expected_anti_affinity)


@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_valid(client):
    """A complete affinity dict should result in V1Affinity receiving both anti-affinity and node affinity."""
    complete_affinity = {"key": "some-key", "operator": "In", "values": ["val1"]}

    _generate_affinities(complete_affinity)

    expected_kwargs = {
        "pod_anti_affinity": client.V1PodAntiAffinity.return_value,
        "node_affinity": client.V1NodeAffinity.return_value,
    }
    client.V1Affinity.assert_called_once_with(**expected_kwargs)


@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
def test_setup_ngem_pv_and_pvcs(setup_smb_pv, setup_pvc):
    """The ngem helper should create a ReadWriteMany PV and PVC and append both names to the passed lists."""
    created_pvs, created_pvcs = [], []
    expected_pv = "job1-ngem-pv-smb"
    expected_pvc = "job1-ngem-pvc"

    _setup_ngem_pv_and_pvcs("job1", "ns1", created_pvs, created_pvcs)

    setup_smb_pv.assert_called_once_with(
        expected_pv, "archive-creds", "ns1", "//isis.cclrc.ac.uk/Science", [], "ReadWriteMany"
    )
    setup_pvc.assert_called_once_with(expected_pvc, expected_pv, "ns1", "ReadWriteMany")
    assert created_pvs == [expected_pv]
    assert created_pvcs == [expected_pvc]


@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
def test_setup_imat_pv_and_pvcs(setup_smb_pv, setup_pvc):
    """The imat helper should create its PV and PVC without passing an access mode, and record both names."""
    created_pvs, created_pvcs = [], []
    expected_pv = "job1-ndximat-pv-smb"
    expected_pvc = "job1-ndximat-pvc"

    _setup_imat_pv_and_pvcs("job1", "ns1", created_pvs, created_pvcs)

    # No access-mode argument is expected in either call.
    setup_smb_pv.assert_called_once_with(expected_pv, "imat-creds", "ns1", "//NDXIMAT.isis.cclrc.ac.uk/data$/", [])
    setup_pvc.assert_called_once_with(expected_pvc, expected_pv, "ns1")
    assert created_pvs == [expected_pv]
    assert created_pvcs == [expected_pvc]


@mock.patch("jobcreator.job_creator.client")
def test_setup_ceph_pv(client):
pv_name = mock.MagicMock()
Expand Down Expand Up @@ -219,6 +300,61 @@ def test_jobcreator_init(mock_load_kubernetes_config):
mock_load_kubernetes_config.assert_called_once()


@mock.patch("jobcreator.job_creator._setup_ngem_pv_and_pvcs")
@mock.patch("jobcreator.job_creator._setup_imat_pv_and_pvcs")
@mock.patch("jobcreator.job_creator._setup_extras_pv")
@mock.patch("jobcreator.job_creator._setup_extras_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_ceph_pv")
@mock.patch("jobcreator.job_creator.load_kubernetes_config")
@mock.patch("jobcreator.job_creator.client")
def test_jobcreator_spawn_job_ngem(
    client,
    _,  # noqa: PT019
    setup_ceph_pv,
    setup_pvc,
    setup_smb_pv,
    setup_extras_pvc,
    setup_extras_pv,
    setup_imat_pv,
    setup_ngem_pv,
):
    """Spawning a job with special_pvs=["ngem"] should set up the ngem PV/PVC and add the ngem volume + mount."""

    def was_built_with_ngem_name(factory):
        # True when any recorded call to the mocked constructor used name="ngem-mount".
        return any(c.kwargs.get("name") == "ngem-mount" for c in factory.call_args_list)

    creator = JobCreator("test-sha", False)

    creator.spawn_job(
        job_name="test-job",
        script="test-script",
        job_namespace="test-ns",
        ceph_creds_k8s_secret_name="some-secret-name",  # noqa: S106
        ceph_creds_k8s_namespace="ns",
        cluster_id="id",
        fs_name="fs",
        ceph_mount_path="/path",
        job_id=1,
        max_time_to_complete_job=100,
        fia_api_host="host",
        fia_api_api_key="key",
        runner_image="image",
        manila_share_id="mid",
        manila_share_access_id="maid",
        special_pvs=["ngem"],
        taints=[],
        affinity=None,
    )

    setup_ngem_pv.assert_called_once()
    # The ngem volume must have been constructed ...
    assert was_built_with_ngem_name(client.V1Volume)
    # ... along with the matching volume mount.
    assert was_built_with_ngem_name(client.V1VolumeMount)


@mock.patch("jobcreator.job_creator._setup_extras_pv")
@mock.patch("jobcreator.job_creator._setup_extras_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
Expand Down
Loading