Skip to content

Add decorator for user-defined authorization check #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions s3_file_field/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from django.conf import settings
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.core import signing
from django.http import HttpRequest, HttpResponse
from django.shortcuts import resolve_url
from django.utils.module_loading import import_string
from rest_framework import serializers
from rest_framework.decorators import api_view, parser_classes
from rest_framework.parsers import JSONParser
Expand Down Expand Up @@ -79,7 +86,61 @@ class FinalizationResponseSerializer(serializers.Serializer):
field_value = serializers.CharField(trim_whitespace=False)


def no_check(request: HttpRequest, *args, **kwargs) -> bool:
return True


def is_site_user(request: HttpRequest, *args, **kwargs) -> bool:
return request.user.is_authenticated


def get_can_user_upload():
return import_string(
getattr(
settings,
"S3_FILE_FIELD_USER_PERMISSION",
"s3_file_field.views.no_check",
)
)


# Simplified version of django.contrib.auth.decorators.user_passes_test, except instead of passing
# only request.user to the test_func, we pass the full request object, *args, and **kwargs
def request_passes_test(test_func):
"""
Check that a request passes the provided user-supplied test function.

Decorator for views that checks that the user passes the given test,
redirecting to the log-in page if necessary. The test should be a callable
that takes the user object and returns True if the user passes.
"""

def decorator(view_func):
@wraps(view_func)
def _wrapper_view(request: HttpRequest, *args, **kwargs) -> HttpResponse:
if test_func(request, *args, **kwargs):
return view_func(request, *args, **kwargs)
path = request.build_absolute_uri()
resolved_login_url = resolve_url(settings.LOGIN_URL)
# If the login url is the same scheme and net location then just
# use the path as the "next" url.
login_scheme, login_netloc = urlparse(resolved_login_url)[:2]
current_scheme, current_netloc = urlparse(path)[:2]
if (not login_scheme or login_scheme == current_scheme) and (
not login_netloc or login_netloc == current_netloc
):
path = request.get_full_path()
from django.contrib.auth.views import redirect_to_login

return redirect_to_login(path, resolved_login_url, REDIRECT_FIELD_NAME)

return _wrapper_view

return decorator


@api_view(["POST"])
@request_passes_test(get_can_user_upload())
@parser_classes([JSONParser])
def upload_initialize(request: Request) -> HttpResponseBase:
request_serializer = UploadInitializationRequestSerializer(data=request.data)
Expand Down Expand Up @@ -126,6 +187,7 @@ def upload_initialize(request: Request) -> HttpResponseBase:


@api_view(["POST"])
@request_passes_test(get_can_user_upload())
@parser_classes([JSONParser])
def upload_complete(request: Request) -> HttpResponseBase:
request_serializer = UploadCompletionRequestSerializer(data=request.data)
Expand Down Expand Up @@ -160,6 +222,7 @@ def upload_complete(request: Request) -> HttpResponseBase:


@api_view(["POST"])
@request_passes_test(get_can_user_upload())
@parser_classes([JSONParser])
def finalize(request: Request) -> HttpResponseBase:
request_serializer = FinalizationRequestSerializer(data=request.data)
Expand Down
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import factory
import pytest
from pytest_mock import MockerFixture
from rest_framework.test import APIClient
from rest_framework.test import APIClient, APIRequestFactory

from s3_file_field._multipart import MultipartManager
from s3_file_field._sizes import mb
Expand All @@ -27,6 +27,11 @@ def _reduce_part_size(mocker: MockerFixture) -> None:
mocker.patch.object(MultipartManager, "part_size", new=mb(5))


@pytest.fixture()
def request_factory() -> APIRequestFactory:
return APIRequestFactory()


@pytest.fixture()
def api_client() -> APIClient:
return APIClient()
Expand Down
78 changes: 77 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,93 @@
from typing import cast
from unittest.mock import Mock

from django.core import signing
from django.core.files.storage import default_storage
from django.urls import reverse
import pytest
import requests
from rest_framework.test import APIClient
from rest_framework.test import APIClient, APIRequestFactory

from s3_file_field._sizes import mb
from s3_file_field.views import (
get_can_user_upload,
is_site_user,
no_check,
request_passes_test,
upload_initialize,
)

from fuzzy import FUZZY_UPLOAD_ID, FUZZY_URL, Fuzzy


def test_can_user_upload_is_no_check() -> None:
assert get_can_user_upload() == no_check


def test_can_user_upload_is_is_site_user(settings) -> None:
settings.S3_FILE_FIELD_USER_PERMISSION = "s3_file_field.views.is_site_user"
assert get_can_user_upload() == is_site_user


@pytest.mark.parametrize(
"view",
[
"upload-initialize",
"upload-complete",
"finalize",
],
)
def test_no_check(request_factory: APIRequestFactory, view: str) -> None:
request = request_factory.post(
reverse(f"s3_file_field:{view}"),
{
"field_id": "test_app.Resource.blob",
"file_name": "test.txt",
"file_size": 10,
"content_type": "text/plain",
},
format="json",
)
assert no_check(request)


def test_is_site_user_with_authenticated_user(settings, request_factory: APIRequestFactory) -> None:
settings.S3_FILE_FIELD_USER_PERMISSION = "s3_file_field.views.is_site_user"
request = request_factory.post(
reverse("s3_file_field:upload-initialize"),
{
"field_id": "test_app.Resource.blob",
"file_name": "test.txt",
"file_size": 10,
"content_type": "text/plain",
},
format="json",
)
request.user = Mock()
request.user.is_authenticated = True
assert is_site_user(request)


def test_is_site_user_with_anonymous_user(settings, request_factory: APIRequestFactory) -> None:
settings.S3_FILE_FIELD_USER_PERMISSION = "s3_file_field.views.is_site_user"
settings.LOGIN_URL = "/login/"
request = request_factory.post(
reverse("s3_file_field:upload-initialize"),
{
"field_id": "test_app.Resource.blob",
"file_name": "test.txt",
"file_size": 10,
"content_type": "text/plain",
},
format="json",
)
request.user = Mock()
request.user.is_authenticated = False
resp = request_passes_test(is_site_user)(upload_initialize)(request)
assert resp.status_code == 302
assert resp.url.startswith("/login/")


def test_prepare(api_client: APIClient) -> None:
resp = api_client.post(
reverse("s3_file_field:upload-initialize"),
Expand Down