|
1 | 1 | import unittest
|
| 2 | +from unittest.mock import patch, MagicMock |
2 | 3 | from flask import Flask
|
3 |
| -from jwt_proxy.api import blueprint, proxy_request |
4 | 4 | import json
|
5 |
| -from unittest.mock import patch, MagicMock |
6 | 5 | import jwt
|
7 |
| -from jwt_proxy.api import CustomJSONProvider |
| 6 | +from jwt_proxy.api import blueprint, proxy_request, CustomJSONProvider, validate_jwt |
8 | 7 |
|
9 | 8 | class TestAuthBlueprint(unittest.TestCase):
|
10 | 9 | def setUp(self):
|
@@ -45,51 +44,6 @@ def test_proxy_request(self, mock_request):
|
45 | 44 | response = proxy_request(req, 'http://example.com/api')
|
46 | 45 | self.assertEqual(response, "Plain text response")
|
47 | 46 |
|
48 |
| - @patch('jwt.PyJWKClient') |
49 |
| - @patch('jwt.decode') |
50 |
| - def test_validate_jwt(self, mock_decode, mock_jwk_client): |
51 |
| - """Test JWT validation and proxying""" |
52 |
| - # Set up mock JWKClient |
53 |
| - mock_key = MagicMock() |
54 |
| - mock_jwk_client_instance = MagicMock() |
55 |
| - mock_jwk_client_instance.get_signing_key_from_jwt.return_value = mock_key |
56 |
| - mock_jwk_client.return_value = mock_jwk_client_instance |
57 |
| - |
58 |
| - # Set up mock JWT decoding |
59 |
| - mock_decode. return_value = { 'email': '[email protected]'} |
60 |
| - self.app.json = CustomJSONProvider(self.app) |
61 |
| - |
62 |
| - # Test whitelisted path without token |
63 |
| - response = self.client.get('/whitelisted', content_type='application/json') |
64 |
| - print(f'Status Code: {response.status_code}') |
65 |
| - print(f'Response Data: {response.data.decode()}') |
66 |
| - print(f'Response JSON: {response.json}') |
67 |
| - self.assertEqual(response.status_code, 200) |
68 |
| - |
69 |
| - # Test valid token |
70 |
| - response = self.client.get('/', headers={'Authorization': 'Bearer valid_token'}) |
71 |
| - print(f'Status Code: {response.status_code}') |
72 |
| - print(f'Response Data: {response.data.decode()}') |
73 |
| - print(f'Response JSON: {response.json}') |
74 |
| - self.assertEqual(response.status_code, 200) |
75 |
| - |
76 |
| - # Test missing token |
77 |
| - response = self.client.get('/') |
78 |
| - print(f'Status Code: {response.status_code}') |
79 |
| - print(f'Response Data: {response.data.decode()}') |
80 |
| - print(f'Response JSON: {response.json}') |
81 |
| - self.assertEqual(response.status_code, 400) |
82 |
| - self.assertEqual(response.json.get('message'), "token missing") |
83 |
| - |
84 |
| - # Test expired token |
85 |
| - mock_decode.side_effect = jwt.exceptions.ExpiredSignatureError() |
86 |
| - response = self.client.get('/', headers={'Authorization': 'Bearer expired_token'}) |
87 |
| - print(f'Status Code: {response.status_code}') |
88 |
| - print(f'Response Data: {response.data.decode()}') |
89 |
| - print(f'Response JSON: {response.json}') |
90 |
| - self.assertEqual(response.status_code, 401) |
91 |
| - self.assertEqual(response.json.get('message'), "token expired") |
92 |
| - |
93 | 47 | def test_smart_configuration(self):
|
94 | 48 | """Test /fhir/.well-known/smart-configuration endpoint"""
|
95 | 49 | response = self.client.get('/fhir/.well-known/smart-configuration')
|
@@ -117,5 +71,72 @@ def test_config_settings(self):
|
117 | 71 | response = self.client.get('/settings/SECRET_KEY')
|
118 | 72 | self.assertEqual(response.status_code, 400)
|
119 | 73 |
|
| 74 | +class TestValidateJWT(unittest.TestCase): |
| 75 | + def setUp(self): |
| 76 | + self.app = Flask(__name__) |
| 77 | + self.app.config["PATH_WHITELIST"] = ["/allowed_path"] |
| 78 | + self.app.config["UPSTREAM_SERVER"] = "http://upstream-server" |
| 79 | + self.app.config["JWKS_URL"] = "http://jwks-url" |
| 80 | + |
| 81 | + # Register the route using the validate_jwt function |
| 82 | + @self.app.route("/", defaults={"relative_path": ""}, methods=["GET", "POST"]) |
| 83 | + @self.app.route("/<path:relative_path>", methods=["GET", "POST"]) |
| 84 | + def validate_jwt_route(relative_path): |
| 85 | + return validate_jwt(relative_path) |
| 86 | + |
| 87 | + self.client = self.app.test_client() |
| 88 | + |
| 89 | + @patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined |
| 90 | + def test_path_whitelist(self, mock_proxy_request): |
| 91 | + # Mock response directly without using jsonify |
| 92 | + mock_proxy_request.return_value = {"message": "request proxied"} |
| 93 | + |
| 94 | + with self.app.app_context(): |
| 95 | + response = self.client.get("/allowed_path") |
| 96 | + |
| 97 | + self.assertEqual(response.status_code, 200) |
| 98 | + self.assertEqual(response.json, {"message": "request proxied"}) |
| 99 | + |
| 100 | + @patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined |
| 101 | + @patch('jwt.PyJWKClient') |
| 102 | + @patch('jwt.decode') |
| 103 | + def test_valid_token(self, mock_decode, mock_jwks_client, mock_proxy_request): |
| 104 | + mock_proxy_request.return_value = {"message": "request proxied"} |
| 105 | + mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key" |
| 106 | + mock_decode. return_value = { "email": "[email protected]"} |
| 107 | + |
| 108 | + headers = {"Authorization": "Bearer valid-token"} |
| 109 | + |
| 110 | + with self.app.app_context(): |
| 111 | + response = self.client.get("/some_path", headers=headers) |
| 112 | + |
| 113 | + self.assertEqual(response.status_code, 200) |
| 114 | + self.assertEqual(response.json, {"message": "request proxied"}) |
| 115 | + |
| 116 | + @patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined |
| 117 | + @patch('jwt.PyJWKClient') |
| 118 | + @patch('jwt.decode') |
| 119 | + def test_missing_token(self, mock_decode, mock_jwks_client, mock_proxy_request): |
| 120 | + with self.app.app_context(): |
| 121 | + response = self.client.get("/some_path") |
| 122 | + |
| 123 | + self.assertEqual(response.status_code, 400) |
| 124 | + self.assertEqual(response.json, {"message": "token missing"}) |
| 125 | + |
| 126 | + @patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined |
| 127 | + @patch('jwt.PyJWKClient') |
| 128 | + @patch('jwt.decode') |
| 129 | + def test_expired_token(self, mock_decode, mock_jwks_client, mock_proxy_request): |
| 130 | + mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key" |
| 131 | + mock_decode.side_effect = jwt.exceptions.ExpiredSignatureError("token expired") |
| 132 | + |
| 133 | + headers = {"Authorization": "Bearer expired-token"} |
| 134 | + |
| 135 | + with self.app.app_context(): |
| 136 | + response = self.client.get("/some_path", headers=headers) |
| 137 | + |
| 138 | + self.assertEqual(response.status_code, 401) |
| 139 | + self.assertEqual(response.json, {"message": "token expired"}) |
| 140 | + |
120 | 141 | if __name__ == '__main__':
|
121 | 142 | unittest.main()
|
0 commit comments