Skip to content

Commit c0f7053

Browse files
authored
Merge pull request #889 from Kaggle/set-gcloud-credentials
Add set_gcloud_credentials to UserSecretsClient
2 parents 393beb5 + 42394fa commit c0f7053

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

patches/kaggle_secrets.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
from datetime import datetime, timedelta
99
from enum import Enum, unique
10+
import subprocess
1011
from typing import Optional, Tuple
1112
from kaggle_web_client import KaggleWebClient
1213
from kaggle_web_client import (CredentialError, BackendError)
@@ -80,6 +81,28 @@ def get_gcloud_credential(self) -> str:
8081
else:
8182
raise
8283

84+
def set_gcloud_credentials(self, project=None, account=None):
85+
"""Set user credentials attached to the current kernel and optionally the project & account name to the `gcloud` CLI.
86+
87+
Example usage:
88+
client = UserSecretsClient()
89+
client.set_gcloud_credentials(project="my-gcp-project", account="[email protected]")
90+
91+
!gcloud ai-platform jobs list
92+
"""
93+
creds = self.get_gcloud_credential()
94+
creds_path = self._write_credentials_file(creds)
95+
96+
subprocess.run(['gcloud', 'config', 'set', 'auth/credential_file_override', creds_path])
97+
98+
if project:
99+
os.environ['GOOGLE_CLOUD_PROJECT'] = project
100+
subprocess.run(['gcloud', 'config', 'set', 'project', project])
101+
102+
if account:
103+
os.environ['GOOGLE_ACCOUNT'] = account
104+
subprocess.run(['gcloud', 'config', 'set', 'account', account])
105+
83106
def set_tensorflow_credential(self, credential):
84107
"""Sets the credential for use by Tensorflow both in the local notebook
85108
and to pass to the TPU.
@@ -89,11 +112,7 @@ def set_tensorflow_credential(self, credential):
89112

90113
# Write to a local JSON credentials file and set
91114
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
92-
adc_path = os.path.join(
93-
os.environ.get('HOME', '/'), 'gcloud_credential.json')
94-
with open(adc_path, 'w') as f:
95-
f.write(credential)
96-
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path
115+
self._write_credentials_file(credential)
97116

98117
# set the credential for the TPU
99118
tensorflow_gcs_config.configure_gcs(credentials=credential)
@@ -108,6 +127,14 @@ def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
108127
token, expiry = client.get_bigquery_access_token()
109128
"""
110129
return self._get_access_token(GcpTarget.BIGQUERY)
130+
131+
def _write_credentials_file(self, credentials) -> str:
132+
adc_path = os.path.join(os.environ.get('HOME', '/'), 'gcloud_credential.json')
133+
with open(adc_path, 'w') as f:
134+
f.write(credentials)
135+
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path
136+
137+
return adc_path
111138

112139
def _get_gcs_access_token(self) -> Tuple[str, Optional[datetime]]:
113140
return self._get_access_token(GcpTarget.GCS)

tests/test_user_secrets.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import subprocess
34
import threading
45
import unittest
56
from http.server import BaseHTTPRequestHandler, HTTPServer
@@ -134,6 +135,34 @@ def call_get_secret():
134135
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"},
135136
success=False)
136137

138+
def test_set_gcloud_credentials_succeeds(self):
139+
secret = '{"client_id":"gcloud","type":"authorized_user"}'
140+
project = 'foo'
141+
account = 'bar'
142+
143+
def get_gcloud_config_value(field):
144+
result = subprocess.run(['gcloud', 'config', 'get-value', field], capture_output=True)
145+
result.check_returncode()
146+
return result.stdout.strip().decode('ascii')
147+
148+
def test_fn():
149+
client = UserSecretsClient()
150+
client.set_gcloud_credentials(project=project, account=account)
151+
152+
self.assertEqual(project, os.environ['GOOGLE_CLOUD_PROJECT'])
153+
self.assertEqual(project, get_gcloud_config_value('project'))
154+
155+
self.assertEqual(account, os.environ['GOOGLE_ACCOUNT'])
156+
self.assertEqual(account, get_gcloud_config_value('account'))
157+
158+
expected_creds_file = '/tmp/gcloud_credential.json'
159+
self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
160+
self.assertEqual(expected_creds_file, get_gcloud_config_value('auth/credential_file_override'))
161+
162+
with open(expected_creds_file, 'r') as f:
163+
self.assertEqual(secret, '\n'.join(f.readlines()))
164+
165+
self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)
137166

138167
@mock.patch('kaggle_secrets.datetime')
139168
def test_get_access_token_succeeds(self, mock_dt):

0 commit comments

Comments
 (0)