diff --git a/poetry.lock b/poetry.lock index 2d56edd..8710d92 100644 --- a/poetry.lock +++ b/poetry.lock @@ -747,6 +747,24 @@ cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (>=1.0.0,<2.0.0)"] +[[package]] +name = "httpx-auth" +version = "0.17.0" +description = "Authentication for HTTPX" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx_auth-0.17.0-py3-none-any.whl", hash = "sha256:5358f2938f8843179dc681cea34626d3589b312bb021425f2cd4a4fbc316e92c"}, + {file = "httpx_auth-0.17.0.tar.gz", hash = "sha256:4e297113804ac3ee316d12a9596bc05e4dd592d2bf0809e5b4dab496d8a35b13"}, +] + +[package.dependencies] +httpx = ">=0.24.0,<0.25.0" + +[package.extras] +testing = ["pyjwt (>=2.0.0,<3.0.0)", "pytest-cov (>=4.0.0,<5.0.0)", "pytest-httpx (>=0.22.0,<0.23.0)"] + [[package]] name = "idna" version = "3.4" @@ -2171,4 +2189,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "10e0de9302cd41354256267e337c7355c47dd8a30123480d24e0a4766d35f967" +content-hash = "871a7e84a57d37e04219e9fb9388a29678bff03ddf5ddb8bb2ac8acee74b4f17" diff --git a/pv_site_api/enode_auth.py b/pv_site_api/enode_auth.py deleted file mode 100644 index 76aea97..0000000 --- a/pv_site_api/enode_auth.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional - -import httpx - - -class EnodeAuth(httpx.Auth): - def __init__( - self, client_id: str, client_secret: str, token_url: str, access_token: Optional[str] = None - ): - self._client_id = client_id - self._client_secret = client_secret - self._token_url = token_url - self._access_token = access_token - - def sync_auth_flow(self, request: httpx.Request): - # Add the Authorization header to the request using the current access token - request.headers["Authorization"] = f"Bearer {self._access_token}" - response = yield request - - if response.status_code == 401: - # The access token is no longer valid, refresh it - token_response = yield self._build_refresh_request() - token_response.read() - self._update_access_token(token_response) - # Update the request's Authorization header with the new access token - request.headers["Authorization"] = f"Bearer {self._access_token}" - # Resend the request with the new access token - yield request - - async def async_auth_flow(self, request: httpx.Request): - # Add the Authorization header to the request using the current access token - request.headers["Authorization"] = f"Bearer {self._access_token}" - response = yield request - - if response.status_code == 401: - # The access token is no longer valid, refresh it - token_response = yield self._build_refresh_request() - await token_response.aread() - self._update_access_token(token_response) - # Update the request's Authorization header with the new access token - request.headers["Authorization"] = f"Bearer {self._access_token}" - # Resend the request with the new access token - yield request - - def _build_refresh_request(self): - basic_auth = httpx.BasicAuth(self._client_id, self._client_secret) - - data = {"grant_type": "client_credentials"} - request = next(basic_auth.auth_flow(httpx.Request("POST", self._token_url, data=data))) - return request - - def _update_access_token(self, response): - self._access_token = response.json()["access_token"] diff --git a/pv_site_api/main.py b/pv_site_api/main.py index 0bcd1ca..8d66d7a 100644 --- a/pv_site_api/main.py +++ b/pv_site_api/main.py @@ -3,6 +3,7 @@ from typing import Any import httpx +from httpx_auth import OAuth2ClientCredentials import pandas as pd import sentry_sdk import structlog @@ -30,7 +31,6 @@ ) from .auth import Auth from .cache import cache_response -from .enode_auth import EnodeAuth from .fake import ( fake_site_uuid, make_fake_enode_link_url, @@ -113,10 +113,10 @@ def is_fake(): algorithm=os.getenv("AUTH0_ALGORITHM"), ) -enode_auth = EnodeAuth( - os.getenv("ENODE_CLIENT_ID", ""), - os.getenv("ENODE_CLIENT_SECRET", ""), - os.getenv("ENODE_TOKEN_URL", "https://oauth.sandbox.enode.io/oauth2/token"), +enode_auth = OAuth2ClientCredentials( + token_url=os.getenv("ENODE_TOKEN_URL", "https://oauth.sandbox.enode.io/oauth2/token"), + client_id=os.getenv("ENODE_CLIENT_ID", ""), + client_secret=os.getenv("ENODE_CLIENT_SECRET", ""), ) enode_api_base_url = os.getenv("ENODE_API_BASE_URL", "https://enode-api.sandbox.enode.io") diff --git a/pv_site_api/utils.py b/pv_site_api/utils.py index 2b6f379..5f77301 100644 --- a/pv_site_api/utils.py +++ b/pv_site_api/utils.py @@ -7,14 +7,13 @@ import httpx -from .enode_auth import EnodeAuth from .pydantic_models import Inverters, InverterValues TOTAL_MINUTES_IN_ONE_DAY = 24 * 60 async def get_inverters_list( - client_uuid: uuid.UUID, inverter_ids: list[str], enode_auth: EnodeAuth, enode_api_base_url: str + client_uuid: uuid.UUID, inverter_ids: list[str], enode_auth: httpx.Auth, enode_api_base_url: str ) -> Inverters: async with httpx.AsyncClient(base_url=enode_api_base_url, auth=enode_auth) as httpx_client: headers = {"Enode-User-Id": str(client_uuid)} diff --git a/pyproject.toml b/pyproject.toml index c773402..51b617f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ sentry-sdk = "^1.16.0" pvlib = "^0.9.5" structlog = "^22.3.0" pyjwt = {extras = ["crypto"], version = "^2.6.0"} +httpx-auth = "^0.17.0" [tool.poetry.group.dev.dependencies] isort = "^5.12.0" diff --git a/tests/conftest.py b/tests/conftest.py index e6a7b71..b71e585 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os from datetime import datetime, timedelta +import pytest_httpx import freezegun import pytest from fastapi.testclient import TestClient @@ -23,6 +24,9 @@ from pv_site_api.session import get_session +enode_token_url = os.getenv("ENODE_TOKEN_URL", "https://oauth.sandbox.enode.io/oauth2/token") + + @pytest.fixture def non_mocked_hosts() -> list: """Prevent TestClient fixture from being mocked""" @@ -71,6 +75,16 @@ def db_session(engine): engine.dispose() +@pytest.fixture() +def mock_enode_auth(httpx_mock): + """Adds mocked response for Enode authentication""" + httpx_mock.add_response( + url=enode_token_url, + # Ensure token expires immediately so that every test must go through Enode auth + json={"access_token": "test.test", "expires_in": 1, "scope": "", "token_type": "bearer"}, + ) + + @pytest.fixture() def clients(db_session): """Make fake client sql""" diff --git a/tests/test_enode.py b/tests/test_enode.py index 7721d93..d49c63e 100644 --- a/tests/test_enode.py +++ b/tests/test_enode.py @@ -11,7 +11,7 @@ def test_get_enode_link_fake(client, fake): assert len(response.json()) > 0 -def test_get_enode_link(client, clients, httpx_mock): +def test_get_enode_link(client, clients, httpx_mock, mock_enode_auth): test_enode_link_uri = "https://example.com" httpx_mock.add_response( diff --git a/tests/test_enode_auth.py b/tests/test_enode_auth.py deleted file mode 100644 index a13831b..0000000 --- a/tests/test_enode_auth.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Test the Enode authentication HTTPX auth class. -""" -import httpx -import pytest - -from pv_site_api.enode_auth import EnodeAuth - -TOKEN_URL = "https://example.com/token" -CLIENT_ID = "ocf" -CLIENT_SECRET = "secret" - -test_enode_base_url = "https://enode.com/api" - - -@pytest.fixture -def enode_auth(): - """An Enode Auth object""" - enode_auth = EnodeAuth(token_url=TOKEN_URL, client_id=CLIENT_ID, client_secret=CLIENT_SECRET) - return enode_auth - - -def test_enode_auth_sync(enode_auth): - request = httpx.Request("GET", f"{test_enode_base_url}/inverters") - gen = enode_auth.sync_auth_flow(request) - authenticated_request = next(gen) - assert authenticated_request.headers["Authorization"] == "Bearer None" - - refresh_request = gen.send(httpx.Response(401)) - assert ( - refresh_request.method == "POST" - and refresh_request.url == httpx.URL(TOKEN_URL) - and refresh_request.content == b"grant_type=client_credentials" - ) - - test_access_token = "test_access_token" - authenticated_request = gen.send(httpx.Response(200, json={"access_token": test_access_token})) - assert authenticated_request.headers["Authorization"] == f"Bearer {test_access_token}" - - try: - next(gen) - except StopIteration: - pass - else: - # The generator should exit - assert False diff --git a/tests/test_inverters.py b/tests/test_inverters.py index 85d3496..cab9b73 100644 --- a/tests/test_inverters.py +++ b/tests/test_inverters.py @@ -14,7 +14,7 @@ def test_put_inverters_for_site_fake(client, sites, fake): assert response.status_code == 200 -def test_put_inverters_for_site(client, sites, httpx_mock): +def test_put_inverters_for_site(client, sites, httpx_mock, mock_enode_auth): test_inverter_client_id = "6c078ca2-2e75-40c8-9a7f-288bd0b70065" json = [test_inverter_client_id] response = client.put(f"/sites/{sites[0].site_uuid}/inverters", json=json) @@ -40,7 +40,7 @@ def test_get_inverters_for_site_fake(client, sites, inverters, fake): assert response.status_code == 200 -def test_get_inverters_for_site(client, sites, inverters, httpx_mock): +def test_get_inverters_for_site(client, sites, inverters, httpx_mock, mock_enode_auth): mock_inverter_response("id1", httpx_mock) mock_inverter_response("id2", httpx_mock) mock_inverter_response("id3", httpx_mock) @@ -66,7 +66,7 @@ def test_get_enode_inverters_fake(client, fake): assert len(response_inverters.inverters) > 0 -def test_get_enode_inverters(client, httpx_mock, clients): +def test_get_enode_inverters(client, httpx_mock, clients, mock_enode_auth): httpx_mock.add_response(url=f"{enode_api_base_url}/inverters", json=["id1"]) mock_inverter_response("id1", httpx_mock)