7
7
import os
8
8
from datetime import datetime , timedelta
9
9
from enum import Enum , unique
10
+ import subprocess
10
11
from typing import Optional , Tuple
11
12
from kaggle_web_client import KaggleWebClient
12
13
from kaggle_web_client import (CredentialError , BackendError )
@@ -80,6 +81,28 @@ def get_gcloud_credential(self) -> str:
80
81
else :
81
82
raise
82
83
84
+ def set_gcloud_credentials (self , project = None , account = None ):
85
+ """Set user credentials attached to the current kernel and optionally the project & account name to the `gcloud` CLI.
86
+
87
+ Example usage:
88
+ client = UserSecretsClient()
89
+ client.set_gcloud_credentials(project="my-gcp-project", account="[email protected] ")
90
+
91
+ !gcloud ai-platform jobs list
92
+ """
93
+ creds = self .get_gcloud_credential ()
94
+ creds_path = self ._write_credentials_file (creds )
95
+
96
+ subprocess .run (['gcloud' , 'config' , 'set' , 'auth/credential_file_override' , creds_path ])
97
+
98
+ if project :
99
+ os .environ ['GOOGLE_CLOUD_PROJECT' ] = project
100
+ subprocess .run (['gcloud' , 'config' , 'set' , 'project' , project ])
101
+
102
+ if account :
103
+ os .environ ['GOOGLE_ACCOUNT' ] = account
104
+ subprocess .run (['gcloud' , 'config' , 'set' , 'account' , account ])
105
+
83
106
def set_tensorflow_credential (self , credential ):
84
107
"""Sets the credential for use by Tensorflow both in the local notebook
85
108
and to pass to the TPU.
@@ -89,11 +112,7 @@ def set_tensorflow_credential(self, credential):
89
112
90
113
# Write to a local JSON credentials file and set
91
114
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
92
- adc_path = os .path .join (
93
- os .environ .get ('HOME' , '/' ), 'gcloud_credential.json' )
94
- with open (adc_path , 'w' ) as f :
95
- f .write (credential )
96
- os .environ ['GOOGLE_APPLICATION_CREDENTIALS' ]= adc_path
115
+ self ._write_credentials_file (credential )
97
116
98
117
# set the credential for the TPU
99
118
tensorflow_gcs_config .configure_gcs (credentials = credential )
@@ -108,6 +127,14 @@ def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
108
127
token, expiry = client.get_bigquery_access_token()
109
128
"""
110
129
return self ._get_access_token (GcpTarget .BIGQUERY )
130
+
131
+ def _write_credentials_file (self , credentials ) -> str :
132
+ adc_path = os .path .join (os .environ .get ('HOME' , '/' ), 'gcloud_credential.json' )
133
+ with open (adc_path , 'w' ) as f :
134
+ f .write (credentials )
135
+ os .environ ['GOOGLE_APPLICATION_CREDENTIALS' ]= adc_path
136
+
137
+ return adc_path
111
138
112
139
def _get_gcs_access_token (self ) -> Tuple [str , Optional [datetime ]]:
113
140
return self ._get_access_token (GcpTarget .GCS )
0 commit comments