Skip to content

Commit 2e29a53

Browse files
authored
Merge pull request Styria-Digital#84 from ashokdelphia/improve-logout-security
Improve logout security
2 parents adafd74 + e27daad commit 2e29a53

File tree

12 files changed

+286
-27
lines changed

12 files changed

+286
-27
lines changed

changelog.d/84.feature.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Apply 'blacklist' to any token from the same line of refreshed tokens as any invalidated token, where token ids are available.
2+
Avoid storing whole auth tokens when JWT_TOKEN_ID setting is set to 'require'.
3+
Please see notes in docs on migrating from JWT_TOKEN_ID 'allow' (the default) to 'require' (recommended).

docs/index.md

+6
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ urlpatterns = [
201201

202202
When called, deletes all blacklisted tokens that have expired.
203203

204+
### Warning
205+
206+
Unless `JWT_TOKEN_ID` is set to `require`, blacklisting tokens will store the entire token value. This creates a potential problem if someone is able to read and delete records from the blacklist, either directly in the database or via the administrative interface. Note that the default value is `include`, not `require`. See the section on `JWT_TOKEN_ID` for how to migrate to requiring token id claims in all tokens.
207+
204208
## Additional Settings
205209
There are some additional settings that you can override similar to how you'd do it with Django REST framework itself. Here are all the available defaults.
206210

@@ -369,6 +373,8 @@ For new installations, please override the default and set this to `require`, as
369373

370374
For existing installations, when migrating from an older version (pre-1.17) or when changing the setting from `off`, we recommend setting this to `require` once all of the valid tokens have the id claims. This will typically be after `JWT_EXPIRATION_DELTA` has elapsed since upgrading or allowing id claims to be included.
371375

376+
Note that when set to `off` or `include`, the blacklist functionality - if used - will store the entire token value, which would allow someone with access to the administrative interface, or directly to the database, to steal an otherwise valid token and remove it from the blacklist. Using `require` for this setting means that only token identifiers are recorded for the blacklist and not entire tokens.
377+
372378
Default is `include`.
373379

374380
### JWT_AUDIENCE

src/rest_framework_jwt/authentication.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,6 @@ def authenticate(self, request):
6868
except MissingToken:
6969
return None
7070

71-
if apps.is_installed('rest_framework_jwt.blacklist'):
72-
from rest_framework_jwt.blacklist.models import BlacklistedToken
73-
if BlacklistedToken.objects.filter(token=force_str(token)).exists():
74-
msg = _('Token is blacklisted.')
75-
raise exceptions.PermissionDenied(msg)
76-
7771
try:
7872
payload = self.jwt_decode_token(token)
7973
except jwt.ExpiredSignature:
@@ -86,6 +80,12 @@ def authenticate(self, request):
8680
msg = _('Invalid token.')
8781
raise exceptions.AuthenticationFailed(msg)
8882

83+
if apps.is_installed('rest_framework_jwt.blacklist'):
84+
from rest_framework_jwt.blacklist.models import BlacklistedToken
85+
if BlacklistedToken.is_blocked(token, payload):
86+
msg = _('Token is blacklisted.')
87+
raise exceptions.PermissionDenied(msg)
88+
8989
user = self.authenticate_credentials(payload)
9090

9191
return user, token
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Generated by Django 3.1.4 on 2020-12-12 12:15
2+
3+
from django.db import migrations, models
4+
from django import VERSION
5+
6+
import jwt
7+
8+
9+
def add_token_id_values(apps, schema_editor):
10+
"""Fill in token_id values for existing token records that have a token id
11+
"""
12+
BlacklistedToken = apps.get_model('blacklist', 'BlacklistedToken')
13+
for row in BlacklistedToken.objects.filter(token_id=None):
14+
payload = jwt.decode(row.token)
15+
token_id = payload.get('orig_jti') or payload.get('jti')
16+
if token_id:
17+
row.token_id = token_id
18+
row.save(update_fields=['token_id'])
19+
20+
21+
class Migration(migrations.Migration):
22+
23+
dependencies = [
24+
('blacklist', '0001_initial'),
25+
]
26+
27+
operations = [
28+
migrations.AddField(
29+
model_name='blacklistedtoken',
30+
name='token_id',
31+
field=models.UUIDField(db_index=True, null=True),
32+
),
33+
migrations.AlterField(
34+
model_name='blacklistedtoken',
35+
name='token',
36+
field=models.TextField(db_index=True, null=True),
37+
),
38+
migrations.RunPython(add_token_id_values, reverse_code=migrations.RunPython.noop)
39+
]
40+
if VERSION >= (2, 2):
41+
operations.append(
42+
migrations.AddConstraint(
43+
model_name='blacklistedtoken',
44+
constraint=models.CheckConstraint(
45+
check=models.Q(token_id__isnull=False) | models.Q(token__isnull=False),
46+
name='token_or_id_not_null',
47+
),
48+
)
49+
)

src/rest_framework_jwt/blacklist/models.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
# -*- coding: utf-8 -*-
22

3+
from django import VERSION
34
from django.conf import settings
45
from django.db import models
6+
from django.db.models import Q
57
from django.utils import timezone
8+
from django.utils.encoding import force_str
9+
10+
from rest_framework_jwt.settings import api_settings
611

712

813
class BlacklistedTokenManager(models.Manager):
@@ -11,7 +16,18 @@ def delete_stale_tokens(self):
1116

1217

1318
class BlacklistedToken(models.Model):
14-
token = models.TextField(db_index=True)
19+
class Meta:
20+
if VERSION >= (2, 2):
21+
constraints = [
22+
models.CheckConstraint(
23+
check=Q(token_id__isnull=False) | Q(token__isnull=False),
24+
name='token_or_id_not_null',
25+
)
26+
]
27+
28+
# This holds the original token id for refreshed tokens with ids
29+
token_id = models.UUIDField(db_index=True, null=True)
30+
token = models.TextField(db_index=True, null=True)
1531
user = models.ForeignKey(
1632
settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
1733
expires_at = models.DateTimeField(db_index=True)
@@ -21,3 +37,22 @@ class BlacklistedToken(models.Model):
2137

2238
def __str__(self):
2339
return 'Blacklisted token - {} - {}'.format(self.user, self.token)
40+
41+
42+
@staticmethod
43+
def is_blocked(token, payload):
44+
token = force_str(token)
45+
46+
# For invalidated tokens that have an original token id (orig_jti),
47+
# we record that in the list, so that the whole family of tokens
48+
# refreshed from the same initial token is rejected.
49+
token_id = payload.get('orig_jti') or payload.get('jti')
50+
51+
if api_settings.JWT_TOKEN_ID == 'require':
52+
query = Q(token_id=token_id)
53+
elif api_settings.JWT_TOKEN_ID == 'off':
54+
query = Q(token=token)
55+
else:
56+
query = Q(token__isnull=False, token=token) | Q(token_id__isnull=False, token_id=token_id)
57+
58+
return BlacklistedToken.objects.filter(query).exists()

src/rest_framework_jwt/blacklist/permissions.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import jwt
2+
13
from rest_framework.permissions import BasePermission
24

35
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
@@ -9,6 +11,12 @@ class IsNotBlacklisted(BasePermission):
911
message = _('You have been blacklisted.')
1012

1113
def has_permission(self, request, view):
12-
return not BlacklistedToken.objects.filter(
13-
token=JSONWebTokenAuthentication.get_token_from_request(request)
14-
).exists()
14+
token = JSONWebTokenAuthentication.get_token_from_request(request)
15+
16+
# Don't check the blacklist for requests with no token.
17+
if token is None:
18+
return True
19+
20+
# The token should already be validated before we call this.
21+
payload = jwt.decode(token, None, False)
22+
return not BlacklistedToken.is_blocked(token, payload)

src/rest_framework_jwt/blacklist/serializers.py

+12
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,21 @@ def save(self, **kwargs):
3737
iat = payload.get('iat', unix_epoch())
3838
expires_at_unix_time = iat + api_settings.JWT_EXPIRATION_DELTA.total_seconds()
3939

40+
# For refreshed tokens, record the token id of the original token.
41+
# This allows us to invalidate the whole family of tokens from
42+
# the same original authentication event.
43+
token_id = payload.get('orig_jti') or payload.get('jti')
44+
4045
self.validated_data.update({
46+
'token_id': token_id,
4147
'user': check_user(payload),
4248
'expires_at':
4349
make_aware(datetime.utcfromtimestamp(expires_at_unix_time)),
4450
})
51+
52+
# Don't store the token if we can rely on token IDs.
53+
# The token values are still sensitive until they expire.
54+
if api_settings.JWT_TOKEN_ID == 'require':
55+
del self.validated_data['token']
56+
4557
return super(BlacklistTokenSerializer, self).save(**kwargs)

src/rest_framework_jwt/utils.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from django.apps import apps
1212
from django.contrib.auth import get_user_model
13-
from django.utils.encoding import force_str
1413

1514
from rest_framework import serializers
1615
from rest_framework.utils.encoders import JSONEncoder
@@ -204,12 +203,6 @@ def jwt_create_response_payload(
204203
def check_payload(token):
205204
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
206205

207-
if apps.is_installed('rest_framework_jwt.blacklist'):
208-
from rest_framework_jwt.blacklist.models import BlacklistedToken
209-
if BlacklistedToken.objects.filter(token=force_str(token)).exists():
210-
msg = _('Token is blacklisted.')
211-
raise serializers.ValidationError(msg)
212-
213206
try:
214207
payload = JSONWebTokenAuthentication.jwt_decode_token(token)
215208
except jwt.ExpiredSignature:
@@ -222,6 +215,12 @@ def check_payload(token):
222215
msg = _('Invalid token.')
223216
raise serializers.ValidationError(msg)
224217

218+
if apps.is_installed('rest_framework_jwt.blacklist'):
219+
from rest_framework_jwt.blacklist.models import BlacklistedToken
220+
if BlacklistedToken.is_blocked(token, payload):
221+
msg = _('Token is blacklisted.')
222+
raise serializers.ValidationError(msg)
223+
225224
return payload
226225

227226

tests/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from django.contrib.auth import get_user_model
88

9+
from rest_framework.reverse import reverse
910
from rest_framework.test import APIClient
1011

1112
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
@@ -80,3 +81,13 @@ def _create_user(**kwargs):
8081
return User.objects.create_user(**kwargs)
8182

8283
return _create_user
84+
85+
86+
@pytest.fixture
87+
def call_auth_refresh_endpoint(api_client, db):
88+
def _call_auth_refresh_endpoint(token):
89+
url = reverse("auth-refresh")
90+
data = {"token": token}
91+
return api_client.post(path=url, data=data)
92+
93+
return _call_auth_refresh_endpoint
+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from datetime import timedelta
4+
from django.utils import timezone
5+
6+
import pytest
7+
8+
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
9+
from rest_framework_jwt.blacklist.models import BlacklistedToken
10+
from rest_framework_jwt.settings import api_settings
11+
12+
import uuid
13+
14+
15+
@pytest.mark.parametrize(
16+
'id_setting', ['require', 'include']
17+
)
18+
def test_token_is_blocked_by_id(user, monkeypatch, id_setting):
19+
monkeypatch.setattr(api_settings, "JWT_TOKEN_ID", id_setting)
20+
payload = JSONWebTokenAuthentication.jwt_create_payload(user)
21+
token = JSONWebTokenAuthentication.jwt_encode_payload(payload)
22+
23+
expiration = timezone.now() + timedelta(days=1)
24+
BlacklistedToken(
25+
token_id=payload['jti'],
26+
expires_at=expiration,
27+
user=user,
28+
).save()
29+
30+
assert BlacklistedToken.is_blocked(token, payload) is True
31+
32+
33+
34+
@pytest.mark.parametrize(
35+
'id_setting', ['require', 'include']
36+
)
37+
def test_refreshed_token_is_blocked_by_original_id(user, call_auth_refresh_endpoint, monkeypatch, id_setting):
38+
monkeypatch.setattr(api_settings, "JWT_TOKEN_ID", id_setting)
39+
original_payload = JSONWebTokenAuthentication.jwt_create_payload(user)
40+
original_token = JSONWebTokenAuthentication.jwt_encode_payload(original_payload)
41+
42+
refresh_response = call_auth_refresh_endpoint(original_token)
43+
refreshed_token = refresh_response.json()['token']
44+
payload = JSONWebTokenAuthentication.jwt_decode_token(refreshed_token)
45+
46+
expiration = timezone.now() + timedelta(days=1)
47+
BlacklistedToken(
48+
token_id=original_payload['jti'],
49+
expires_at=expiration,
50+
user=user,
51+
).save()
52+
53+
assert BlacklistedToken.is_blocked(refreshed_token, payload) is True
54+
55+
56+
@pytest.mark.parametrize(
57+
'id_setting', ['include', 'off']
58+
)
59+
def test_token_is_blocked_by_value(user, monkeypatch, id_setting):
60+
monkeypatch.setattr(api_settings, "JWT_TOKEN_ID", id_setting)
61+
payload = JSONWebTokenAuthentication.jwt_create_payload(user)
62+
token = JSONWebTokenAuthentication.jwt_encode_payload(payload)
63+
64+
expiration = timezone.now() + timedelta(days=1)
65+
BlacklistedToken(
66+
token=token,
67+
expires_at=expiration,
68+
user=user,
69+
).save()
70+
71+
assert BlacklistedToken.is_blocked(token, payload) is True
72+
73+
74+
def test_token_is_not_blocked_by_value_when_ids_required(user, monkeypatch):
75+
monkeypatch.setattr(api_settings, "JWT_TOKEN_ID", "require")
76+
payload = JSONWebTokenAuthentication.jwt_create_payload(user)
77+
token = JSONWebTokenAuthentication.jwt_encode_payload(payload)
78+
79+
expiration = timezone.now() + timedelta(days=1)
80+
BlacklistedToken(
81+
token=token,
82+
expires_at=expiration,
83+
user=user,
84+
).save()
85+
86+
assert BlacklistedToken.is_blocked(token, payload) is False
87+
88+
89+
def test_token_is_not_blocked_by_id_when_ids_disabled(user, monkeypatch):
90+
monkeypatch.setattr(api_settings, "JWT_TOKEN_ID", "off")
91+
payload = JSONWebTokenAuthentication.jwt_create_payload(user)
92+
payload['jti'] = uuid.uuid4()
93+
token = JSONWebTokenAuthentication.jwt_encode_payload(payload)
94+
95+
expiration = timezone.now() + timedelta(days=1)
96+
BlacklistedToken(
97+
token_id=payload['jti'],
98+
expires_at=expiration,
99+
user=user,
100+
).save()
101+
102+
assert BlacklistedToken.is_blocked(token, payload) is False

tests/views/conftest.py

-10
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,3 @@ def _call_auth_verify_endpoint(token):
2525
return api_client.post(path=url, data=data)
2626

2727
return _call_auth_verify_endpoint
28-
29-
30-
@pytest.fixture
31-
def call_auth_refresh_endpoint(api_client, db):
32-
def _call_auth_refresh_endpoint(token):
33-
url = reverse("auth-refresh")
34-
data = {"token": token}
35-
return api_client.post(path=url, data=data)
36-
37-
return _call_auth_refresh_endpoint

0 commit comments

Comments
 (0)