Skip to content

Commit 6f6af47

Browse files
committed
Add more tests
1 parent 5d05ffc commit 6f6af47

File tree

2 files changed

+105
-5
lines changed

2 files changed

+105
-5
lines changed

jwt_proxy/api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def proxy_request(req, upstream_url, user_info=None):
4545
return result
4646

4747

48+
4849
@blueprint.route("/", defaults={"relative_path": ""}, methods=SUPPORTED_METHODS)
4950
@blueprint.route("/<path:relative_path>", methods=SUPPORTED_METHODS)
5051
def validate_jwt(relative_path):
@@ -56,23 +57,23 @@ def validate_jwt(relative_path):
5657
)
5758
return response_content
5859

59-
token = request.headers.get("Authorization", "").split("Bearer ")[-1]
60+
token = request.headers.get("authorization", "").split("Bearer ")[-1]
6061
if not token:
6162
return jsonify(message="token missing"), 400
6263

64+
jwks_client = jwt.PyJWKClient(current_app.config["JWKS_URL"])
65+
signing_key = jwks_client.get_signing_key_from_jwt(token)
66+
6367
try:
64-
jwks_client = jwt.PyJWKClient(current_app.config["JWKS_URL"])
65-
signing_key = jwks_client.get_signing_key_from_jwt(token)
6668
decoded_token = jwt.decode(
6769
jwt=token,
70+
# TODO cache public key in redis
6871
key=signing_key.key,
6972
algorithms=("RS256"),
7073
audience=("account"),
7174
)
7275
except jwt.exceptions.ExpiredSignatureError:
7376
return jsonify(message="token expired"), 401
74-
except Exception as e:
75-
return jsonify(message=str(e)), 400
7677

7778
response_content = proxy_request(
7879
req=request,

tests/test_validation.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from flask import Flask, jsonify, request
4+
import jwt
5+
6+
# Assume blueprint and validate_jwt function are defined in your application
7+
# For testing purposes, we'll use a simple Flask app
8+
app = Flask(__name__)
9+
app.config["PATH_WHITELIST"] = ["/allowed_path"]
10+
app.config["UPSTREAM_SERVER"] = "http://upstream-server"
11+
app.config["JWKS_URL"] = "http://jwks-url"
12+
13+
@app.route("/", defaults={"relative_path": ""}, methods=["GET", "POST"])
14+
@app.route("/<path:relative_path>", methods=["GET", "POST"])
15+
def validate_jwt(relative_path):
16+
"""Validate JWT and pass to upstream server"""
17+
if f"/{relative_path}" in app.config["PATH_WHITELIST"]:
18+
response_content = proxy_request(
19+
req=request,
20+
upstream_url=f"{app.config['UPSTREAM_SERVER']}/{relative_path}",
21+
)
22+
return response_content
23+
24+
token = request.headers.get("authorization", "").split("Bearer ")[-1]
25+
if not token:
26+
return jsonify(message="token missing"), 400
27+
28+
jwks_client = jwt.PyJWKClient(app.config["JWKS_URL"])
29+
signing_key = jwks_client.get_signing_key_from_jwt(token)
30+
31+
try:
32+
decoded_token = jwt.decode(
33+
jwt=token,
34+
key=signing_key.key,
35+
algorithms=("RS256"),
36+
audience=("account"),
37+
)
38+
except jwt.exceptions.ExpiredSignatureError:
39+
return jsonify(message="token expired"), 401
40+
41+
response_content = proxy_request(
42+
req=request,
43+
upstream_url=f"{app.config['UPSTREAM_SERVER']}/{relative_path}",
44+
user_info=decoded_token.get("email") or decoded_token.get("preferred_username"),
45+
)
46+
return response_content
47+
48+
def proxy_request(req, upstream_url, user_info=None):
49+
# Dummy implementation for testing purposes
50+
return jsonify(message="request proxied")
51+
52+
class TestValidateJWT(unittest.TestCase):
53+
54+
def setUp(self):
55+
app.testing = True
56+
self.client = app.test_client()
57+
58+
@patch('your_module.proxy_request')
59+
def test_path_whitelist(self, mock_proxy_request):
60+
mock_proxy_request.return_value = jsonify(message="request proxied")
61+
response = self.client.get("/allowed_path")
62+
self.assertEqual(response.status_code, 200)
63+
self.assertEqual(response.json, {"message": "request proxied"})
64+
65+
@patch('your_module.proxy_request')
66+
@patch('jwt.PyJWKClient')
67+
@patch('jwt.decode')
68+
def test_valid_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
69+
mock_proxy_request.return_value = jsonify(message="request proxied")
70+
mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key"
71+
mock_decode.return_value = {"email": "[email protected]"}
72+
73+
headers = {"Authorization": "Bearer valid-token"}
74+
response = self.client.get("/some_path", headers=headers)
75+
self.assertEqual(response.status_code, 200)
76+
self.assertEqual(response.json, {"message": "request proxied"})
77+
78+
@patch('your_module.proxy_request')
79+
@patch('jwt.PyJWKClient')
80+
@patch('jwt.decode')
81+
def test_missing_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
82+
response = self.client.get("/some_path")
83+
self.assertEqual(response.status_code, 400)
84+
self.assertEqual(response.json, {"message": "token missing"})
85+
86+
@patch('your_module.proxy_request')
87+
@patch('jwt.PyJWKClient')
88+
@patch('jwt.decode')
89+
def test_expired_token(self, mock_decode, mock_jwks_client, mock_proxy_request):
90+
mock_jwks_client.return_value.get_signing_key_from_jwt.return_value.key = "test-key"
91+
mock_decode.side_effect = jwt.exceptions.ExpiredSignatureError("token expired")
92+
93+
headers = {"Authorization": "Bearer expired-token"}
94+
response = self.client.get("/some_path", headers=headers)
95+
self.assertEqual(response.status_code, 401)
96+
self.assertEqual(response.json, {"message": "token expired"})
97+
98+
if __name__ == '__main__':
99+
unittest.main()

0 commit comments

Comments
 (0)