Skip to content

Commit e03ed29

Browse files
committed
new
1 parent 7206230 commit e03ed29

File tree

2 files changed

+69
-124
lines changed

2 files changed

+69
-124
lines changed

Diff for: tests/test_api.py

+69-48
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import unittest
2+
from unittest.mock import patch, MagicMock
23
from flask import Flask
3-
from jwt_proxy.api import blueprint, proxy_request
44
import json
5-
from unittest.mock import patch, MagicMock
65
import jwt
7-
from jwt_proxy.api import CustomJSONProvider
6+
from jwt_proxy.api import blueprint, proxy_request, CustomJSONProvider, validate_jwt
87

98
class TestAuthBlueprint(unittest.TestCase):
109
def setUp(self):
@@ -45,51 +44,6 @@ def test_proxy_request(self, mock_request):
4544
response = proxy_request(req, 'http://example.com/api')
4645
self.assertEqual(response, "Plain text response")
4746

48-
@patch('jwt.PyJWKClient')
49-
@patch('jwt.decode')
50-
def test_validate_jwt(self, mock_decode, mock_jwk_client):
51-
"""Test JWT validation and proxying"""
52-
# Set up mock JWKClient
53-
mock_key = MagicMock()
54-
mock_jwk_client_instance = MagicMock()
55-
mock_jwk_client_instance.get_signing_key_from_jwt.return_value = mock_key
56-
mock_jwk_client.return_value = mock_jwk_client_instance
57-
58-
# Set up mock JWT decoding
59-
mock_decode.return_value = {'email': '[email protected]'}
60-
self.app.json = CustomJSONProvider(self.app)
61-
62-
# Test whitelisted path without token
63-
response = self.client.get('/whitelisted', content_type='application/json')
64-
print(f'Status Code: {response.status_code}')
65-
print(f'Response Data: {response.data.decode()}')
66-
print(f'Response JSON: {response.json}')
67-
self.assertEqual(response.status_code, 200)
68-
69-
# Test valid token
70-
response = self.client.get('/', headers={'Authorization': 'Bearer valid_token'})
71-
print(f'Status Code: {response.status_code}')
72-
print(f'Response Data: {response.data.decode()}')
73-
print(f'Response JSON: {response.json}')
74-
self.assertEqual(response.status_code, 200)
75-
76-
# Test missing token
77-
response = self.client.get('/')
78-
print(f'Status Code: {response.status_code}')
79-
print(f'Response Data: {response.data.decode()}')
80-
print(f'Response JSON: {response.json}')
81-
self.assertEqual(response.status_code, 400)
82-
self.assertEqual(response.json.get('message'), "token missing")
83-
84-
# Test expired token
85-
mock_decode.side_effect = jwt.exceptions.ExpiredSignatureError()
86-
response = self.client.get('/', headers={'Authorization': 'Bearer expired_token'})
87-
print(f'Status Code: {response.status_code}')
88-
print(f'Response Data: {response.data.decode()}')
89-
print(f'Response JSON: {response.json}')
90-
self.assertEqual(response.status_code, 401)
91-
self.assertEqual(response.json.get('message'), "token expired")
92-
9347
def test_smart_configuration(self):
9448
"""Test /fhir/.well-known/smart-configuration endpoint"""
9549
response = self.client.get('/fhir/.well-known/smart-configuration')
@@ -117,5 +71,72 @@ def test_config_settings(self):
11771
response = self.client.get('/settings/SECRET_KEY')
11872
self.assertEqual(response.status_code, 400)
11973

74+
class TestValidateJWT(unittest.TestCase):
75+
def setUp(self):
76+
self.app = Flask(__name__)
77+
self.app.config["PATH_WHITELIST"] = ["/allowed_path"]
78+
self.app.config["UPSTREAM_SERVER"] = "http://upstream-server"
79+
self.app.config["JWKS_URL"] = "http://jwks-url"
80+
81+
# Register the route using the validate_jwt function
82+
@self.app.route("/", defaults={"relative_path": ""}, methods=["GET", "POST"])
83+
@self.app.route("/<path:relative_path>", methods=["GET", "POST"])
84+
def validate_jwt_route(relative_path):
85+
return validate_jwt(relative_path)
86+
87+
self.client = self.app.test_client()
88+
89+
@patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined
90+
def test_path_whitelist(self, mock_proxy_request):
91+
# Mock response directly without using jsonify
92+
mock_proxy_request.return_value = {"message": "request proxied"}
93+
94+
with self.app.app_context():
95+
response = self.client.get("/allowed_path")
96+
97+
self.assertEqual(response.status_code, 200)
98+
self.assertEqual(response.json, {"message": "request proxied"})
99+
100+
@patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined
101+
@patch('jwt.PyJWKClient')
102+
@patch('jwt.decode')
103+
def test_valid_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
104+
mock_proxy_request.return_value = {"message": "request proxied"}
105+
mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key"
106+
mock_decode.return_value = {"email": "[email protected]"}
107+
108+
headers = {"Authorization": "Bearer valid-token"}
109+
110+
with self.app.app_context():
111+
response = self.client.get("/some_path", headers=headers)
112+
113+
self.assertEqual(response.status_code, 200)
114+
self.assertEqual(response.json, {"message": "request proxied"})
115+
116+
@patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined
117+
@patch('jwt.PyJWKClient')
118+
@patch('jwt.decode')
119+
def test_missing_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
120+
with self.app.app_context():
121+
response = self.client.get("/some_path")
122+
123+
self.assertEqual(response.status_code, 400)
124+
self.assertEqual(response.json, {"message": "token missing"})
125+
126+
@patch('jwt_proxy.api.proxy_request') # Adjust the import path based on where proxy_request is defined
127+
@patch('jwt.PyJWKClient')
128+
@patch('jwt.decode')
129+
def test_expired_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
130+
mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key"
131+
mock_decode.side_effect = jwt.exceptions.ExpiredSignatureError("token expired")
132+
133+
headers = {"Authorization": "Bearer expired-token"}
134+
135+
with self.app.app_context():
136+
response = self.client.get("/some_path", headers=headers)
137+
138+
self.assertEqual(response.status_code, 401)
139+
self.assertEqual(response.json, {"message": "token expired"})
140+
120141
if __name__ == '__main__':
121142
unittest.main()

Diff for: tests/test_validation.py

-76
This file was deleted.

0 commit comments

Comments
 (0)