diff --git a/environment.yml b/environment.yml index 9c36689..896305c 100644 --- a/environment.yml +++ b/environment.yml @@ -15,3 +15,4 @@ dependencies: - pytest-cov - pytest-mock - rbc +- requests diff --git a/pymapd/_samlutils.py b/pymapd/_samlutils.py new file mode 100644 index 0000000..5c057ba --- /dev/null +++ b/pymapd/_samlutils.py @@ -0,0 +1,77 @@ +import re +import requests +from html import unescape +from urllib.parse import urlparse + + +def get_saml_response(idpurl, + username, + password, + userformfield, + passwordformfield, + sslverify=True): + """ + Obtains the SAML response from an Identity Provider + given the provided username and password. + + Parameters + ---------- + idpurl : str + The logon page of the SAML Identity Provider + username : str + SAML Username + password : str + SAML Password + userformfield : str + The HTML form ID for the username + passwordformfield : str + The HTML form ID for the password + sslverify : bool, optional + Verify TLS certificates, by default True + """ + + session = requests.Session() + + response = session.get(idpurl, verify=sslverify) + initialurl = response.url + formaction = initialurl + # print(page.content) + + # Determine if there's an action in the form, if there is, + # use it instead of the page URL + asearch = re.search(r'', + response.text, re.IGNORECASE | re.DOTALL) + + if asearch: + formaction = asearch.group(1) + + # If the action is a path not a URL, build the full + if not formaction.lower().startswith('http'): + parsedurl = urlparse(idpurl) + formaction = parsedurl.scheme + "://" + parsedurl.netloc + formaction + + # Un-urlencode the URL + formaction = unescape(formaction) + + formpayload = { + userformfield: username, + passwordformfield: password + } + + response = session.post(formaction, data=formpayload, verify=sslverify) + + samlresponse = None + ssearch = re.search(r'', + response.text, re.IGNORECASE | re.DOTALL) + if ssearch: + samlresponse = ssearch.group(1) + # Remove any whitespace, some providers include + # new lines in the response (!) + re.sub(r"[\r\n\t\s]*", "", samlresponse) + + if not samlresponse: + raise ValueError('No SAMLResponse found in response.') + + return samlresponse diff --git a/pymapd/connection.py b/pymapd/connection.py index d5958d3..df1da6a 100644 --- a/pymapd/connection.py +++ b/pymapd/connection.py @@ -27,6 +27,7 @@ from .ipc import load_buffer, shmdt from ._pandas_loaders import build_row_desc, _serialize_arrow_payload from . import _pandas_loaders +from ._samlutils import get_saml_response from packaging.version import Version @@ -47,6 +48,10 @@ def connect(uri=None, sessionid=None, bin_cert_validate=None, bin_ca_certs=None, + idpurl=None, + idpformusernamefield='username', + idpformpasswordfield='password', + idpsslverify=True, ): """ Create a new Connection. @@ -65,6 +70,15 @@ def connect(uri=None, Whether to continue if there is any certificate error bin_ca_certs: str, optional, binary encrypted connection only Path to the CA certificate file + idpurl : str + EXPERIMENTAL Enable SAML authentication by providing + the logon page of the SAML Identity Provider. + idpformusernamefield: str + The HTML form ID for the username, defaults to 'username'. + idpformpasswordfield: str + The HTML form ID for the password, defaults to 'password'. + idpsslverify: str + Enable / disable certificate checking, defaults to True. Returns ------- @@ -82,6 +96,10 @@ def connect(uri=None, >>> connect(user='admin', password='HyperInteractive', host='localhost', ... port=6274, dbname='omnisci') + >>> connect(user='admin', password='HyperInteractive', host='localhost', + ... port=443, idpurl='https://sso.localhost/logon', + protocol='https') + >>> connect(sessionid='XihlkjhdasfsadSDoasdllMweieisdpo', host='localhost', ... port=6273, protocol='http') @@ -89,7 +107,10 @@ def connect(uri=None, return Connection(uri=uri, user=user, password=password, host=host, port=port, dbname=dbname, protocol=protocol, sessionid=sessionid, bin_cert_validate=bin_cert_validate, - bin_ca_certs=bin_ca_certs) + bin_ca_certs=bin_ca_certs, idpurl=idpurl, + idpformusernamefield=idpformusernamefield, + idpformpasswordfield=idpformpasswordfield, + idpsslverify=idpsslverify) def _parse_uri(uri): @@ -146,13 +167,17 @@ def __init__(self, sessionid=None, bin_cert_validate=None, bin_ca_certs=None, + idpurl=None, + idpformusernamefield='username', + idpformpasswordfield='password', + idpsslverify=True, ): self.sessionid = None if sessionid is not None: - if any([user, password, uri, dbname]): + if any([user, password, uri, dbname, idpurl]): raise TypeError("Cannot specify sessionid with user, password," - " dbname, or uri") + " dbname, uri, or idpurl") if uri is not None: if not all([user is None, password is None, @@ -161,7 +186,8 @@ def __init__(self, dbname is None, protocol == 'binary', bin_cert_validate is None, - bin_ca_certs is None]): + bin_ca_certs is None, + idpurl is None]): raise TypeError("Cannot specify both URI and other arguments") user, password, host, port, dbname, protocol, \ bin_cert_validate, bin_ca_certs = _parse_uri(uri) @@ -220,6 +246,21 @@ def __init__(self, self.get_tables() self.sessionid = sessionid else: + if idpurl: + self._user = '' + self._password = get_saml_response( + username=user, + password=password, + idpurl=idpurl, + userformfield=idpformusernamefield, + passwordformfield=idpformpasswordfield, + sslverify=idpsslverify) + self._dbname = '' + self._idpsslverify = idpsslverify + user = self._user + password = self._password + dbname = self._dbname + self._session = self._client.connect(user, password, dbname) except TMapDException as e: raise _translate_exception(e) from e diff --git a/setup.py b/setup.py index 3c90ef5..63bb74a 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ 'pandas >= 1.0,<2.0', 'pyarrow >= 0.12.0,<0.14', 'packaging >= 20.0', + 'requests >= 2.23.0', + 'numba >= 0.48', 'rbc-project == 0.2.0dev0'] # Optional Requirements