diff --git a/graphistry/arrow_uploader.py b/graphistry/arrow_uploader.py index a96b17c5b3..19ba78c7fc 100644 --- a/graphistry/arrow_uploader.py +++ b/graphistry/arrow_uploader.py @@ -322,10 +322,8 @@ def sso_get_token(self, state): # from .pygraphistry import PyGraphistry base_path = self.server_base_path - out = requests.get( - f'{base_path}/api/v2/o/sso/oidc/jwt/{state}/', - verify=self.certificate_validation - ) + url = f'{base_path}/api/v2/o/sso/oidc/jwt/{state}/' + out = requests.get(url,verify=self.certificate_validation) json_response = None try: json_response = out.json() @@ -340,6 +338,8 @@ def sso_get_token(self, state): if 'active_organization' in json_response['data']: logger.debug("@ArrowUploader.sso_get_token, org_name: %s", json_response['data']['active_organization']['slug']) self.org_name = json_response['data']['active_organization']['slug'] + if 'state' in json_response['data']: + self.state = json_response['data']['state'] except Exception as e: logger.error('Unexpected SSO authentication error: %s', out, exc_info=True) diff --git a/graphistry/pygraphistry.py b/graphistry/pygraphistry.py index 98916f6a01..94cdde87c5 100644 --- a/graphistry/pygraphistry.py +++ b/graphistry/pygraphistry.py @@ -32,6 +32,7 @@ ############################################################################### SSO_GET_TOKEN_ELAPSE_SECONDS = 50 +SSO_STATE_SPLIT = "|" EnvVarNames = { "api_key": "GRAPHISTRY_API_KEY", @@ -250,12 +251,12 @@ def sso_login(org_name=None, idp_name=None, sso_timeout=SSO_GET_TOKEN_ELAPSE_SEC return auth_url @staticmethod - def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type): - """Internal function to handle what to do with the auth_url + def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type='browser', relogin=False): + """Internal function to handle what to do with the auth_url based on the client mode python/ipython console or notebook. :param auth_url: SSO auth url retrieved via API - :type auth_url: str + :type auth_url: str or list in list([[name1,url1], [name2,url2], [name3,url3]) :param sso_timeout: Set sso login getting token timeout in seconds (blocking mode), set to None if non-blocking mode. Default as SSO_GET_TOKEN_ELAPSE_SECONDS. :type sso_timeout: Optional[int] :param sso_opt_into_type: Show the SSO url with display(), webbrowser.open(), or print() @@ -270,7 +271,18 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type): if in_ipython() or in_databricks() or sso_opt_into_type == 'display': # If run in notebook, just display the HTML # from IPython.core.display import HTML from IPython.display import display, HTML - display(HTML(f'Login SSO')) + if isinstance(auth_url ,list): + for auth_url_each in auth_url: + # pop up window + # display(HTML(f'{auth_url_each[0]}')) + display(HTML(f'{auth_url_each[0]}')) + else: + # display(HTML(f'Login SSO')) + display(HTML(f'Login SSO')) + if relogin: + print("Please click the above link to open browser to access SSO organization") + else: + print("Please click the above link to open browser to login") print("Please click the above URL to open browser to login") print(f"If you cannot see the URL, please open browser, browse to this URL: {auth_url}") print("Please close browser tab after SSO login to back to notebook") @@ -279,9 +291,31 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type): print("Please minimize browser after your SSO login and go back to pygraphistry") import webbrowser - input("Press Enter to open browser ...") - # open browser to auth_url - webbrowser.open(auth_url) + if isinstance(auth_url ,list): + if len(auth_url) == 1: + print(f"idp name: {auth_url[0][0]}") + input("Press Enter to open browser ...") + webbrowser.open(auth_url[0][1]) + else: + while True: + url_dict = {} + for index, (name, url) in enumerate(auth_url, start=1): + url_dict[str(index)] = url + print(f"{index}: {name}") + input_key = input("Enter a number above to open the browser or 'quit' to exit: ") + + if input_key in url_dict: + selected_url = url_dict[input_key] + webbrowser.open(selected_url) + break + elif input_key.strip().lower() == 'quit': + break + else: + print("Invalid key. No URL found.") + else: + input("Press Enter to open browser ...") + # open browser to auth_url + webbrowser.open(auth_url) else: print(f"Please open a browser, browse to this URL, and sign in: {auth_url}") print("After, if you get timeout error, run graphistry.sso_get_token() to complete the authentication") @@ -313,7 +347,7 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type): # set org_name to sso org PyGraphistry._config['org_name'] = org_name - print("Successfully logged in") + print(f"Successfully logged in, current active organization is {org_name}") return PyGraphistry.api_token() else: return None @@ -2405,10 +2439,33 @@ def switch_org(value): result = PyGraphistry._handle_api_response(response) if result is True: - PyGraphistry._config['org_name'] = value.strip() - logger.info("Switched to organization: {}".format(value.strip())) + PyGraphistry._api_response_switch_org(response) else: # print the error message raise Exception(result) + + @staticmethod + def _api_response_switch_org(response): + try: + json_response = response.json() + message = json_response.get('message', '') + data = json_response.get('data', '') + if message.startswith('Switch to organization'): + PyGraphistry._config['org_name'] = data['organization_slug'] + logger.info("Switched to organization: {}".format(data['organization_slug'])) + elif message.startswith('Login to SSO for switch to organization') or message.startswith('Choose SSO for switch to organization'): + idp_name_with_url_list = [] + idp_state_list = [] + for idp_name in data['idp']: + idp_name_with_url_list.append([idp_name, data['idp'][idp_name]['auth_url']]) + idp_state_list.append(data['idp'][idp_name]['state']) + multiple_idp_state = SSO_STATE_SPLIT.join(idp_state_list) + PyGraphistry.sso_state(multiple_idp_state) + PyGraphistry._handle_auth_url(idp_name_with_url_list, sso_timeout=SSO_GET_TOKEN_ELAPSE_SECONDS, relogin=True) + else: + return message + except: + logger.error('Error: %s', response, exc_info=True) + raise Exception("Unknown Error") @staticmethod def _handle_api_response(response):