Skip to content

Commit cd91bc4

Browse files
authored
[dy] Add OIDC generic provider (mage-ai#4563)
* [dy] Add OIDC generic provider * [dy] Update environment variables * [dy] Delete test code
1 parent 4062b13 commit cd91bc4

File tree

8 files changed

+200
-5
lines changed

8 files changed

+200
-5
lines changed

mage_ai/authentication/oauth/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class ProviderName(str, Enum):
1616
GHE = 'ghe'
1717
GOOGLE = 'google'
1818
OKTA = 'okta'
19+
OIDC_GENERIC = 'oidc_generic'
1920

2021

2122
VALID_OAUTH_PROVIDERS = [e.value for e in ProviderName]

mage_ai/authentication/providers/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mage_ai.authentication.providers.ghe import GHEProvider
55
from mage_ai.authentication.providers.gitlab import GitlabProvider
66
from mage_ai.authentication.providers.google import GoogleProvider
7+
from mage_ai.authentication.providers.oidc import OidcProvider
78
from mage_ai.authentication.providers.okta import OktaProvider
89

910
NAME_TO_PROVIDER = {
@@ -12,5 +13,6 @@
1213
ProviderName.GHE: GHEProvider,
1314
ProviderName.GITLAB: GitlabProvider,
1415
ProviderName.GOOGLE: GoogleProvider,
16+
ProviderName.OIDC_GENERIC: OidcProvider,
1517
ProviderName.OKTA: OktaProvider,
1618
}

mage_ai/authentication/providers/gitlab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def get_auth_url_response(self, redirect_uri: str = None, **kwargs) -> Dict:
4545
f'{base_url}/oauth',
4646
),
4747
response_type='code',
48-
state=uuid.uuid4().hex,
4948
scope='read_user+write_repository+api',
49+
state=uuid.uuid4().hex,
5050
)
5151
query_strings = []
5252
for k, v in query.items():
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import urllib.parse
2+
import uuid
3+
from typing import Awaitable, Dict
4+
5+
import aiohttp
6+
import requests
7+
8+
from mage_ai.authentication.oauth.constants import ProviderName
9+
from mage_ai.authentication.providers.oauth import OauthProvider
10+
from mage_ai.authentication.providers.sso import SsoProvider
11+
from mage_ai.authentication.providers.utils import get_base_url
12+
from mage_ai.server.logger import Logger
13+
from mage_ai.settings.sso import OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_DISCOVERY_URL
14+
15+
logger = Logger().new_server_logger(__name__)
16+
17+
18+
class OidcProvider(OauthProvider, SsoProvider):
19+
provider = ProviderName.OIDC_GENERIC
20+
21+
def __init__(self):
22+
self.discovery_url = OIDC_DISCOVERY_URL
23+
self.client_id = OIDC_CLIENT_ID
24+
self.client_secret = OIDC_CLIENT_SECRET
25+
self.__validate()
26+
27+
self.discover()
28+
29+
def __validate(self):
30+
if not self.discovery_url:
31+
raise Exception(
32+
'OIDC discovery url is empty. '
33+
'Make sure the OIDC_DISCOVERY_URL environment variable is set.')
34+
if not self.client_id:
35+
raise Exception(
36+
'OIDC client id is empty. '
37+
'Make sure the OIDC_CLIENT_ID environment variable is set.')
38+
39+
def discover(self) -> Dict:
40+
"""
41+
Call discovery url to get the endpoints for the OIDC server
42+
"""
43+
try:
44+
response = requests.get(
45+
self.discovery_url,
46+
headers={
47+
'Accept': 'application/json',
48+
},
49+
timeout=10,
50+
)
51+
52+
response.raise_for_status()
53+
except Exception:
54+
logger.exception('Could not fetch response from OIDC discovery url')
55+
raise
56+
57+
data = response.json()
58+
59+
self.authorization_endpoint = data.get('authorization_endpoint')
60+
self.token_endpoint = data.get('token_endpoint')
61+
self.userinfo_endpoint = data.get('userinfo_endpoint')
62+
63+
def get_auth_url_response(self, redirect_uri: str = None, **kwargs) -> Dict:
64+
base_url = get_base_url(redirect_uri)
65+
redirect_uri_query = dict(
66+
provider=self.provider,
67+
redirect_uri=redirect_uri,
68+
)
69+
query = dict(
70+
client_id=self.client_id,
71+
redirect_uri=urllib.parse.quote_plus(
72+
f'{base_url}/oauth',
73+
),
74+
response_type='code',
75+
scope='openid profile email',
76+
state=uuid.uuid4().hex,
77+
)
78+
query_strings = []
79+
for k, v in query.items():
80+
query_strings.append(f'{k}={v}')
81+
82+
return dict(
83+
url=f"{self.authorization_endpoint}?{'&'.join(query_strings)}",
84+
redirect_query_params=redirect_uri_query,
85+
)
86+
87+
async def get_access_token_response(self, code: str, **kwargs) -> Awaitable[Dict]:
88+
base_url = get_base_url(kwargs.get('redirect_uri'))
89+
data = dict()
90+
91+
payload = dict(
92+
client_id=self.client_id,
93+
grant_type='authorization_code',
94+
code=code,
95+
redirect_uri=f'{base_url}/oauth',
96+
)
97+
98+
if self.client_secret:
99+
payload['client_secret'] = self.client_secret
100+
101+
async with aiohttp.ClientSession() as session:
102+
async with session.post(
103+
self.token_endpoint,
104+
headers={
105+
'Accept': 'application/json',
106+
},
107+
data=payload,
108+
timeout=20,
109+
) as response:
110+
data = await response.json()
111+
112+
return data
113+
114+
async def get_user_info(self, access_token: str = None, **kwargs) -> Awaitable[Dict]:
115+
if access_token is None:
116+
raise Exception('Access token is required to fetch user info.')
117+
async with aiohttp.ClientSession() as session:
118+
async with session.get(
119+
self.userinfo_endpoint,
120+
headers={
121+
'Authorization': f'Bearer {access_token}',
122+
},
123+
timeout=10,
124+
) as response:
125+
userinfo_resp = await response.json()
126+
127+
email = userinfo_resp.get('email')
128+
129+
return dict(
130+
email=email,
131+
username=userinfo_resp.get('preferred_username', email),
132+
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import React, { useMemo } from 'react';
2+
import { useRouter } from 'next/router';
3+
4+
import KeyboardShortcutButton from '@oracle/elements/Button/KeyboardShortcutButton';
5+
import { SignInProps } from './constants';
6+
import { queryFromUrl } from '@utils/url';
7+
import { set } from '@storage/localStorage';
8+
9+
type OidcSignInProps = {} & SignInProps;
10+
11+
function OidcSignIn({
12+
oauthResponse,
13+
}: OidcSignInProps) {
14+
const router = useRouter();
15+
const {
16+
url: oauthUrl,
17+
redirect_query_params: redirectQueryParams = {},
18+
} = useMemo(() => oauthResponse || {}, [oauthResponse]);
19+
20+
return (
21+
<>
22+
{oauthUrl && (
23+
<KeyboardShortcutButton
24+
bold
25+
inline
26+
onClick={() => {
27+
const q = queryFromUrl(oauthUrl);
28+
const state = q.state;
29+
set(state, redirectQueryParams);
30+
router.push(oauthUrl);
31+
}}
32+
uuid="SignForm/oidc_generic"
33+
>
34+
Sign in with OIDC
35+
</KeyboardShortcutButton>
36+
)}
37+
</>
38+
);
39+
}
40+
41+
export default OidcSignIn;

mage_ai/frontend/interfaces/OauthType.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import GoogleSignIn from '@components/Sessions/GoogleSignIn';
22
import MicrosoftSignIn from '@components/Sessions/MicrosoftSignIn';
3+
import OidcSignIn from '@components/Sessions/OidcSignIn';
34
import OktaSignIn from '@components/Sessions/OktaSignIn';
45

56
export enum OauthProviderEnum {
@@ -8,12 +9,14 @@ export enum OauthProviderEnum {
89
GITHUB = 'github',
910
GITLAB = 'gitlab',
1011
GOOGLE = 'google',
12+
OIDC_GENERIC = 'oidc_generic',
1113
OKTA = 'okta',
1214
}
1315

1416
export const OAUTH_PROVIDER_SIGN_IN_MAPPING = {
1517
[OauthProviderEnum.ACTIVE_DIRECTORY]: MicrosoftSignIn,
1618
[OauthProviderEnum.GOOGLE]: GoogleSignIn,
19+
[OauthProviderEnum.OIDC_GENERIC]: OidcSignIn,
1720
[OauthProviderEnum.OKTA]: OktaSignIn,
1821
};
1922

mage_ai/settings/__init__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,31 @@ def get_bool_value(value: str) -> bool:
165165
'SENTRY_DSN',
166166
'SENTRY_TRACES_SAMPLE_RATE',
167167
'MAGE_PUBLIC_HOST',
168-
'ACTIVE_DIRECTORY_DIRECTORY_ID',
169-
'ACTIVE_DIRECTORY_CLIENT_ID',
170-
'ACTIVE_DIRECTORY_CLIENT_SECRET',
171168
'SCHEDULER_TRIGGER_INTERVAL',
172169
'REQUIRE_USER_PERMISSIONS',
173170
'ENABLE_PROMETHEUS',
174171
'OTEL_EXPORTER_OTLP_ENDPOINT',
175172
'OTEL_EXPORTER_OTLP_HTTP_ENDPOINT',
173+
'MAX_FILE_CACHE_SIZE',
174+
# Oauth variables
175+
'ACTIVE_DIRECTORY_DIRECTORY_ID',
176+
'ACTIVE_DIRECTORY_CLIENT_ID',
177+
'ACTIVE_DIRECTORY_CLIENT_SECRET',
176178
'OKTA_DOMAIN_URL',
177179
'OKTA_CLIENT_ID',
178180
'OKTA_CLIENT_SECRET',
179181
'GOOGLE_CLIENT_ID',
180182
'GOOGLE_CLIENT_SECRET',
183+
'OIDC_CLIENT_ID',
184+
'OIDC_CLIENT_SECRET',
185+
'OIDC_DISCOVERY_URL',
181186
'GHE_CLIENT_ID',
182187
'GHE_CLIENT_SECRET',
183188
'GHE_HOSTNAME',
184-
'MAX_FILE_CACHE_SIZE',
189+
'BITBUCKET_HOST',
190+
'BITBUCKET_OAUTH_KEY',
191+
'BITBUCKET_OAUTH_SECRET',
192+
'GITLAB_HOST',
193+
'GITLAB_CLIENT_ID',
194+
'GITLAB_CLIENT_SECRET',
185195
]

mage_ai/settings/sso.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@
1919
ACTIVE_DIRECTORY_CLIENT_ID = os.getenv('ACTIVE_DIRECTORY_CLIENT_ID')
2020
ACTIVE_DIRECTORY_CLIENT_SECRET = os.getenv('ACTIVE_DIRECTORY_CLIENT_SECRET')
2121
ACTIVE_DIRECTORY_ROLES_MAPPING = os.getenv('ACTIVE_DIRECTORY_ROLES_MAPPING')
22+
23+
# OIDC
24+
25+
OIDC_CLIENT_ID = os.getenv('OIDC_CLIENT_ID')
26+
OIDC_CLIENT_SECRET = os.getenv('OIDC_CLIENT_SECRET')
27+
OIDC_DISCOVERY_URL = os.getenv('OIDC_DISCOVERY_URL')

0 commit comments

Comments
 (0)