1
1
import unittest
2
- from unittest .mock import patch , MagicMock
2
+ from unittest .mock import patch
3
3
from flask import Flask , jsonify , request
4
4
import jwt
5
-
6
- # Assume blueprint and validate_jwt function are defined in your application
7
- # For testing purposes, we'll use a simple Flask app
8
- app = Flask (__name__ )
9
- app .config ["PATH_WHITELIST" ] = ["/allowed_path" ]
10
- app .config ["UPSTREAM_SERVER" ] = "http://upstream-server"
11
- app .config ["JWKS_URL" ] = "http://jwks-url"
12
-
13
- @app .route ("/" , defaults = {"relative_path" : "" }, methods = ["GET" , "POST" ])
14
- @app .route ("/<path:relative_path>" , methods = ["GET" , "POST" ])
15
- def validate_jwt (relative_path ):
16
- """Validate JWT and pass to upstream server"""
17
- if f"/{ relative_path } " in app .config ["PATH_WHITELIST" ]:
18
- response_content = proxy_request (
19
- req = request ,
20
- upstream_url = f"{ app .config ['UPSTREAM_SERVER' ]} /{ relative_path } " ,
21
- )
22
- return response_content
23
-
24
- token = request .headers .get ("authorization" , "" ).split ("Bearer " )[- 1 ]
25
- if not token :
26
- return jsonify (message = "token missing" ), 400
27
-
28
- jwks_client = jwt .PyJWKClient (app .config ["JWKS_URL" ])
29
- signing_key = jwks_client .get_signing_key_from_jwt (token )
30
-
31
- try :
32
- decoded_token = jwt .decode (
33
- jwt = token ,
34
- key = signing_key .key ,
35
- algorithms = ("RS256" ),
36
- audience = ("account" ),
37
- )
38
- except jwt .exceptions .ExpiredSignatureError :
39
- return jsonify (message = "token expired" ), 401
40
-
41
- response_content = proxy_request (
42
- req = request ,
43
- upstream_url = f"{ app .config ['UPSTREAM_SERVER' ]} /{ relative_path } " ,
44
- user_info = decoded_token .get ("email" ) or decoded_token .get ("preferred_username" ),
45
- )
46
- return response_content
47
-
48
- def proxy_request (req , upstream_url , user_info = None ):
49
- # Dummy implementation for testing purposes
50
- return jsonify (message = "request proxied" )
5
+ from jwt_proxy .api import validate_jwt
51
6
52
7
class TestValidateJWT (unittest .TestCase ):
53
8
54
9
def setUp (self ):
55
- app .testing = True
10
+ app = Flask (__name__ )
11
+ app .config ["PATH_WHITELIST" ] = ["/allowed_path" ]
12
+ app .config ["UPSTREAM_SERVER" ] = "http://upstream-server"
13
+ app .config ["JWKS_URL" ] = "http://jwks-url"
14
+
15
+ @app .route ("/" , defaults = {"relative_path" : "" }, methods = ["GET" , "POST" ])
16
+ @app .route ("/<path:relative_path>" , methods = ["GET" , "POST" ])
17
+ def validate_jwt_route (relative_path ):
18
+ return validate_jwt (relative_path )
19
+
20
+ self .app = app
56
21
self .client = app .test_client ()
57
22
58
- @patch ('jwt_proxy.api.proxy_request' )
23
+ @patch ('jwt_proxy.api.proxy_request' ) # Adjust the import path for proxy_request
59
24
def test_path_whitelist (self , mock_proxy_request ):
60
25
# Mock response as a Flask Response object directly
61
- mock_proxy_request .return_value = self . client . get ( '/allowed_path' )
26
+ mock_proxy_request .return_value = jsonify ( message = "request proxied" )
62
27
response = self .client .get ("/allowed_path" )
63
28
self .assertEqual (response .status_code , 200 )
64
29
self .assertEqual (response .json , {"message" : "request proxied" })
65
30
66
- @patch ('jwt_proxy.api.proxy_request' )
31
+ @patch ('jwt_proxy.api.proxy_request' ) # Adjust the import path for proxy_request
67
32
@patch ('jwt.PyJWKClient' )
68
33
@patch ('jwt.decode' )
69
34
def test_valid_token (self , mock_decode , mock_jwks_client , mock_proxy_request ):
70
- mock_proxy_request .return_value = self . client . get ( '/some_path' )
35
+ mock_proxy_request .return_value = jsonify ( message = "request proxied" )
71
36
mock_jwks_client .return_value .get_signing_key_from_jwt .return_value .key = "test-key"
72
37
mock_decode .
return_value = {
"email" :
"[email protected] " }
73
38
@@ -76,15 +41,15 @@ def test_valid_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
76
41
self .assertEqual (response .status_code , 200 )
77
42
self .assertEqual (response .json , {"message" : "request proxied" })
78
43
79
- @patch ('jwt_proxy.api.proxy_request' )
44
+ @patch ('jwt_proxy.api.proxy_request' ) # Adjust the import path for proxy_request
80
45
@patch ('jwt.PyJWKClient' )
81
46
@patch ('jwt.decode' )
82
47
def test_missing_token (self , mock_decode , mock_jwks_client , mock_proxy_request ):
83
48
response = self .client .get ("/some_path" )
84
49
self .assertEqual (response .status_code , 400 )
85
50
self .assertEqual (response .json , {"message" : "token missing" })
86
51
87
- @patch ('jwt_proxy.api.proxy_request' )
52
+ @patch ('jwt_proxy.api.proxy_request' ) # Adjust the import path for proxy_request
88
53
@patch ('jwt.PyJWKClient' )
89
54
@patch ('jwt.decode' )
90
55
def test_expired_token (self , mock_decode , mock_jwks_client , mock_proxy_request ):
@@ -98,4 +63,3 @@ def test_expired_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
98
63
99
64
if __name__ == '__main__' :
100
65
unittest .main ()
101
-
0 commit comments