diff --git a/.gitignore b/.gitignore index eef1052fe..fbb31b2b9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ Pipfile.lock poetry.lock .venv* build/ +tls-cluster-namespace +quicktest.yaml diff --git a/src/codeflare_sdk/cluster/auth.py b/src/codeflare_sdk/cluster/auth.py index 33ad8cf7d..85db3d61d 100644 --- a/src/codeflare_sdk/cluster/auth.py +++ b/src/codeflare_sdk/cluster/auth.py @@ -20,8 +20,12 @@ """ import abc -import openshift as oc -from openshift import OpenShiftPythonException +from kubernetes import client, config + +global api_client +api_client = None +global config_path +config_path = None class Authentication(metaclass=abc.ABCMeta): @@ -43,80 +47,131 @@ def logout(self): pass +class KubeConfiguration(metaclass=abc.ABCMeta): + """ + An abstract class that defines the method for loading a user defined config file using the `load_kube_config()` function + """ + + def load_kube_config(self): + """ + Method for setting your Kubernetes configuration to a certain file + """ + pass + + def logout(self): + """ + Method for logging out of the remote cluster + """ + pass + + class TokenAuthentication(Authentication): """ - `TokenAuthentication` is a subclass of `Authentication`. It can be used to authenticate to an OpenShift + `TokenAuthentication` is a subclass of `Authentication`. It can be used to authenticate to a Kubernetes cluster when the user has an API token and the API server address. """ - def __init__(self, token: str = None, server: str = None, skip_tls: bool = False): + def __init__( + self, + token: str, + server: str, + skip_tls: bool = False, + ca_cert_path: str = None, + ): """ Initialize a TokenAuthentication object that requires a value for `token`, the API Token - and `server`, the API server address for authenticating to an OpenShift cluster. + and `server`, the API server address for authenticating to a Kubernetes cluster. """ self.token = token self.server = server self.skip_tls = skip_tls + self.ca_cert_path = ca_cert_path def login(self) -> str: """ - This function is used to login to an OpenShift cluster using the user's API token and API server address. - Depending on the cluster, a user can choose to login in with "--insecure-skip-tls-verify` by setting `skip_tls` - to `True`. + This function is used to log in to a Kubernetes cluster using the user's API token and API server address. + Depending on the cluster, a user can choose to login in with `--insecure-skip-tls-verify` by setting `skip_tls` + to `True` or `--certificate-authority` by setting `skip_tls` to False and providing a path to a ca bundle with `ca_cert_path`. """ - args = [f"--token={self.token}", f"--server={self.server}"] - if self.skip_tls: - args.append("--insecure-skip-tls-verify") + global config_path + global api_client try: - response = oc.invoke("login", args) - except OpenShiftPythonException as osp: # pragma: no cover - error_msg = osp.result.err() - if "The server uses a certificate signed by unknown authority" in error_msg: - return "Error: certificate auth failure, please set `skip_tls=True` in TokenAuthentication" - elif "invalid" in error_msg: - raise PermissionError(error_msg) + configuration = client.Configuration() + configuration.api_key_prefix["authorization"] = "Bearer" + configuration.host = self.server + configuration.api_key["authorization"] = self.token + if self.skip_tls == False and self.ca_cert_path == None: + configuration.verify_ssl = True + elif self.skip_tls == False: + configuration.ssl_ca_cert = self.ca_cert_path else: - return error_msg - return response.out() + configuration.verify_ssl = False + api_client = client.ApiClient(configuration) + client.AuthenticationApi(api_client).get_api_group() + config_path = None + return "Logged into %s" % self.server + except client.ApiException: # pragma: no cover + api_client = None + print("Authentication Error please provide the correct token + server") def logout(self) -> str: """ - This function is used to logout of an OpenShift cluster. + This function is used to logout of a Kubernetes cluster. """ - args = [f"--token={self.token}", f"--server={self.server}"] - response = oc.invoke("logout", args) - return response.out() + global config_path + config_path = None + global api_client + api_client = None + return "Successfully logged out of %s" % self.server -class PasswordUserAuthentication(Authentication): +class KubeConfigFileAuthentication(KubeConfiguration): """ - `PasswordUserAuthentication` is a subclass of `Authentication`. It can be used to authenticate to an OpenShift - cluster when the user has a username and password. + A class that defines the necessary methods for passing a user's own Kubernetes config file. + Specifically this class defines the `load_kube_config()` and `config_check()` functions. """ - def __init__( - self, - username: str = None, - password: str = None, - ): - """ - Initialize a PasswordUserAuthentication object that requires a value for `username` - and `password` for authenticating to an OpenShift cluster. - """ - self.username = username - self.password = password + def __init__(self, kube_config_path: str = None): + self.kube_config_path = kube_config_path - def login(self) -> str: + def load_kube_config(self): """ - This function is used to login to an OpenShift cluster using the user's `username` and `password`. + Function for loading a user's own predefined Kubernetes config file. """ - response = oc.login(self.username, self.password) - return response.out() + global config_path + global api_client + try: + if self.kube_config_path == None: + return "Please specify a config file path" + config_path = self.kube_config_path + api_client = None + config.load_kube_config(config_path) + response = "Loaded user config file at path %s" % self.kube_config_path + except config.ConfigException: # pragma: no cover + config_path = None + raise Exception("Please specify a config file path") + return response + + +def config_check() -> str: + """ + Function for loading the config file at the default config location ~/.kube/config if the user has not + specified their own config file or has logged in with their token and server. + """ + global config_path + global api_client + if config_path == None and api_client == None: + config.load_kube_config() + if config_path != None and api_client == None: + return config_path - def logout(self) -> str: - """ - This function is used to logout of an OpenShift cluster. - """ - response = oc.invoke("logout") - return response.out() + +def api_config_handler() -> str: + """ + This function is used to load the api client if the user has logged in + """ + if api_client != None and config_path == None: + return api_client + else: + return None diff --git a/src/codeflare_sdk/cluster/awload.py b/src/codeflare_sdk/cluster/awload.py index ecf432133..12544ebac 100644 --- a/src/codeflare_sdk/cluster/awload.py +++ b/src/codeflare_sdk/cluster/awload.py @@ -24,6 +24,7 @@ from kubernetes import client, config from ..utils.kube_api_helpers import _kube_api_error_handling +from .auth import config_check, api_config_handler class AWManager: @@ -57,8 +58,8 @@ def submit(self) -> None: Attempts to create the AppWrapper custom resource using the yaml file """ try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) api_instance.create_namespaced_custom_object( group="mcad.ibm.com", version="v1beta1", @@ -82,8 +83,8 @@ def remove(self) -> None: return try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) api_instance.delete_namespaced_custom_object( group="mcad.ibm.com", version="v1beta1", diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 6fb57abbe..ff92bfcf0 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -23,6 +23,7 @@ from ray.job_submission import JobSubmissionClient +from .auth import config_check, api_config_handler from ..utils import pretty_print from ..utils.generate_yaml import generate_appwrapper from ..utils.kube_api_helpers import _kube_api_error_handling @@ -35,8 +36,8 @@ RayClusterStatus, ) from kubernetes import client, config - import yaml +import os class Cluster: @@ -68,7 +69,9 @@ def create_app_wrapper(self): if self.config.namespace is None: self.config.namespace = get_current_namespace() - if type(self.config.namespace) is not str: + if self.config.namespace is None: + print("Please specify with namespace=") + elif type(self.config.namespace) is not str: raise TypeError( f"Namespace {self.config.namespace} is of type {type(self.config.namespace)}. Check your Kubernetes Authentication." ) @@ -114,8 +117,8 @@ def up(self): """ namespace = self.config.namespace try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) with open(self.app_wrapper_yaml) as f: aw = yaml.load(f, Loader=yaml.FullLoader) api_instance.create_namespaced_custom_object( @@ -135,8 +138,8 @@ def down(self): """ namespace = self.config.namespace try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) api_instance.delete_namespaced_custom_object( group="mcad.ibm.com", version="v1beta1", @@ -247,8 +250,8 @@ def cluster_dashboard_uri(self) -> str: Returns a string containing the cluster's dashboard URI. """ try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", @@ -376,15 +379,29 @@ def list_all_queued(namespace: str, print_to_console: bool = True): def get_current_namespace(): # pragma: no cover - try: - config.load_kube_config() - _, active_context = config.list_kube_config_contexts() - except Exception as e: - return _kube_api_error_handling(e) - try: - return active_context["context"]["namespace"] - except KeyError: - return "default" + if api_config_handler() != None: + if os.path.isfile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"): + try: + file = open( + "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r" + ) + active_context = file.readline().strip("\n") + return active_context + except Exception as e: + print("Unable to find current namespace") + return None + else: + print("Unable to find current namespace") + return None + else: + try: + _, active_context = config.list_kube_config_contexts(config_check()) + except Exception as e: + return _kube_api_error_handling(e) + try: + return active_context["context"]["namespace"] + except KeyError: + return None def get_cluster(cluster_name: str, namespace: str = "default"): @@ -423,8 +440,8 @@ def _get_ingress_domain(): def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) aws = api_instance.list_namespaced_custom_object( group="mcad.ibm.com", version="v1beta1", @@ -442,8 +459,8 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]: try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1alpha1", @@ -462,8 +479,8 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]: def _get_ray_clusters(namespace="default") -> List[RayCluster]: list_of_clusters = [] try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1alpha1", @@ -484,8 +501,8 @@ def _get_app_wrappers( list_of_app_wrappers = [] try: - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) aws = api_instance.list_namespaced_custom_object( group="mcad.ibm.com", version="v1beta1", @@ -511,8 +528,8 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: else: status = RayClusterStatus.UNKNOWN - config.load_kube_config() - api_instance = client.CustomObjectsApi() + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", diff --git a/src/codeflare_sdk/utils/generate_cert.py b/src/codeflare_sdk/utils/generate_cert.py index 2d73621b8..04b04d3e0 100644 --- a/src/codeflare_sdk/utils/generate_cert.py +++ b/src/codeflare_sdk/utils/generate_cert.py @@ -19,6 +19,7 @@ from cryptography import x509 from cryptography.x509.oid import NameOID import datetime +from ..cluster.auth import config_check, api_config_handler from kubernetes import client, config @@ -82,8 +83,8 @@ def generate_tls_cert(cluster_name, namespace, days=30): # Similar to: # oc get secret ca-secret- -o template='{{index .data "ca.key"}}' # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt - config.load_kube_config() - v1 = client.CoreV1Api() + config_check() + v1 = client.CoreV1Api(api_config_handler()) secret = v1.read_namespaced_secret(f"ca-secret-{cluster_name}", namespace).data ca_cert = secret.get("ca.crt") ca_key = secret.get("ca.key") diff --git a/tests/unit_test.py b/tests/unit_test.py index 57d606e49..21c1adf24 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -21,7 +21,7 @@ parent = Path(__file__).resolve().parents[1] sys.path.append(str(parent) + "/src") -from kubernetes import client +from kubernetes import client, config from codeflare_sdk.cluster.awload import AWManager from codeflare_sdk.cluster.cluster import ( Cluster, @@ -35,8 +35,8 @@ ) from codeflare_sdk.cluster.auth import ( TokenAuthentication, - PasswordUserAuthentication, Authentication, + KubeConfigFileAuthentication, ) from codeflare_sdk.utils.pretty_print import ( print_no_resources_found, @@ -65,7 +65,6 @@ ) import openshift -from openshift import OpenShiftPythonException from openshift.selector import Selector import ray from torchx.specs import AppDryRunInfo, AppDef @@ -89,120 +88,79 @@ def att_side_effect(self): return self.high_level_operation -def att_side_effect_tls(self): - if "--insecure-skip-tls-verify" in self.high_level_operation[1]: - return self.high_level_operation - else: - raise OpenShiftPythonException( - "The server uses a certificate signed by unknown authority" - ) - - def test_token_auth_creation(): try: - token_auth = TokenAuthentication() - assert token_auth.token == None - assert token_auth.server == None - assert token_auth.skip_tls == False - - token_auth = TokenAuthentication("token") - assert token_auth.token == "token" - assert token_auth.server == None - assert token_auth.skip_tls == False - - token_auth = TokenAuthentication("token", "server") + token_auth = TokenAuthentication(token="token", server="server") assert token_auth.token == "token" assert token_auth.server == "server" assert token_auth.skip_tls == False + assert token_auth.ca_cert_path == None - token_auth = TokenAuthentication("token", server="server") + token_auth = TokenAuthentication(token="token", server="server", skip_tls=True) assert token_auth.token == "token" assert token_auth.server == "server" - assert token_auth.skip_tls == False + assert token_auth.skip_tls == True + assert token_auth.ca_cert_path == None - token_auth = TokenAuthentication(token="token", server="server") + token_auth = TokenAuthentication(token="token", server="server", skip_tls=False) assert token_auth.token == "token" assert token_auth.server == "server" assert token_auth.skip_tls == False + assert token_auth.ca_cert_path == None - token_auth = TokenAuthentication(token="token", server="server", skip_tls=True) + token_auth = TokenAuthentication( + token="token", server="server", skip_tls=False, ca_cert_path="path/to/cert" + ) assert token_auth.token == "token" assert token_auth.server == "server" - assert token_auth.skip_tls == True + assert token_auth.skip_tls == False + assert token_auth.ca_cert_path == "path/to/cert" except Exception: assert 0 == 1 def test_token_auth_login_logout(mocker): - mocker.patch("openshift.invoke", side_effect=arg_side_effect) - mock_res = mocker.patch.object(openshift.Result, "out") - mock_res.side_effect = lambda: att_side_effect(fake_res) + mocker.patch.object(client, "ApiClient") - token_auth = TokenAuthentication(token="testtoken", server="testserver:6443") - assert token_auth.login() == ( - "login", - ["--token=testtoken", "--server=testserver:6443"], - ) - assert token_auth.logout() == ( - "logout", - ["--token=testtoken", "--server=testserver:6443"], + token_auth = TokenAuthentication( + token="testtoken", server="testserver:6443", skip_tls=False, ca_cert_path=None ) + assert token_auth.login() == ("Logged into testserver:6443") + assert token_auth.logout() == ("Successfully logged out of testserver:6443") def test_token_auth_login_tls(mocker): - mocker.patch("openshift.invoke", side_effect=arg_side_effect) - mock_res = mocker.patch.object(openshift.Result, "out") - mock_res.side_effect = lambda: att_side_effect_tls(fake_res) - - # FIXME - Pytest mocker not allowing caught exception - # token_auth = TokenAuthentication(token="testtoken", server="testserver") - # assert token_auth.login() == "Error: certificate auth failure, please set `skip_tls=True` in TokenAuthentication" + mocker.patch.object(client, "ApiClient") token_auth = TokenAuthentication( - token="testtoken", server="testserver:6443", skip_tls=True + token="testtoken", server="testserver:6443", skip_tls=True, ca_cert_path=None ) - assert token_auth.login() == ( - "login", - ["--token=testtoken", "--server=testserver:6443", "--insecure-skip-tls-verify"], + assert token_auth.login() == ("Logged into testserver:6443") + token_auth = TokenAuthentication( + token="testtoken", server="testserver:6443", skip_tls=False, ca_cert_path=None ) + assert token_auth.login() == ("Logged into testserver:6443") + token_auth = TokenAuthentication( + token="testtoken", + server="testserver:6443", + skip_tls=False, + ca_cert_path="path/to/cert", + ) + assert token_auth.login() == ("Logged into testserver:6443") -def test_passwd_auth_creation(): - try: - passwd_auth = PasswordUserAuthentication() - assert passwd_auth.username == None - assert passwd_auth.password == None - - passwd_auth = PasswordUserAuthentication("user") - assert passwd_auth.username == "user" - assert passwd_auth.password == None - - passwd_auth = PasswordUserAuthentication("user", "passwd") - assert passwd_auth.username == "user" - assert passwd_auth.password == "passwd" - - passwd_auth = PasswordUserAuthentication("user", password="passwd") - assert passwd_auth.username == "user" - assert passwd_auth.password == "passwd" - - passwd_auth = PasswordUserAuthentication(username="user", password="passwd") - assert passwd_auth.username == "user" - assert passwd_auth.password == "passwd" - - except Exception: - assert 0 == 1 - - -def test_passwd_auth_login_logout(mocker): - mocker.patch("openshift.invoke", side_effect=arg_side_effect) - mocker.patch("openshift.login", side_effect=arg_side_effect) - mock_res = mocker.patch.object(openshift.Result, "out") - mock_res.side_effect = lambda: att_side_effect(fake_res) +def test_load_kube_config(mocker): + mocker.patch.object(config, "load_kube_config") + kube_config_auth = KubeConfigFileAuthentication( + kube_config_path="/path/to/your/config" + ) + response = kube_config_auth.load_kube_config() - token_auth = PasswordUserAuthentication(username="user", password="passwd") - assert token_auth.login() == ("user", "passwd") - assert token_auth.logout() == ("logout",) + assert ( + response + == "Loaded user config file at path %s" % kube_config_auth.kube_config_path + ) def test_auth_coverage():