From 94c376709090d1deb3cf277ec11d11e5cb002d4a Mon Sep 17 00:00:00 2001 From: AndrewLester Date: Fri, 28 Apr 2023 14:06:42 -0500 Subject: [PATCH] Add client fetch to auth dependency and create new dependency for async stuff --- pv_site_api/_db_helpers.py | 19 +++++-- pv_site_api/auth.py | 21 ++++++- pv_site_api/cache.py | 2 +- pv_site_api/enode_auth.py | 22 ++++++-- pv_site_api/main.py | 110 ++++++++++++------------------------- tests/conftest.py | 5 +- tests/test_auth.py | 6 +- tests/test_enode.py | 8 +-- tests/test_enode_auth.py | 14 ++--- 9 files changed, 102 insertions(+), 105 deletions(-) diff --git a/pv_site_api/_db_helpers.py b/pv_site_api/_db_helpers.py index fd2a5b9..e6cfb2b 100644 --- a/pv_site_api/_db_helpers.py +++ b/pv_site_api/_db_helpers.py @@ -13,6 +13,7 @@ import sqlalchemy as sa import structlog +from fastapi import Depends from pvsite_datamodel.read.generation import get_pv_generation_by_sites from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, InverterSQL, SiteSQL from sqlalchemy.orm import Session, aliased @@ -24,6 +25,7 @@ PVSiteMetadata, SiteForecastValues, ) +from .session import get_session logger = structlog.stdlib.get_logger() @@ -60,12 +62,6 @@ def _get_forecasts_for_horizon( return list(session.execute(stmt)) -def _get_inverters_by_site(session: Session, site_uuid: str) -> list[Row]: - query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid) - - return query.all() - - def _get_latest_forecast_by_sites( session: Session, site_uuids: list[str], start_utc: Optional[dt.datetime] = None ) -> list[Row]: @@ -240,3 +236,14 @@ def does_site_exist(session: Session, site_uuid: str) -> bool: session.execute(sa.select(SiteSQL).where(SiteSQL.site_uuid == site_uuid)).one_or_none() is not None ) + + +def get_inverters_for_site( + site_uuid: str, session: Session = Depends(get_session) +) -> list[Row] | None: + """Path dependency to get a list of inverters for a site, or None if the site doesn't exist""" + if not does_site_exist(session, site_uuid): + return None + + query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid) + return query.all() diff --git a/pv_site_api/auth.py b/pv_site_api/auth.py index a795830..2b80ac8 100644 --- a/pv_site_api/auth.py +++ b/pv_site_api/auth.py @@ -1,6 +1,10 @@ import jwt from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pvsite_datamodel.sqlmodels import ClientSQL +from sqlalchemy.orm import Session + +from .session import get_session token_auth_scheme = HTTPBearer() @@ -15,7 +19,11 @@ def __init__(self, domain: str, api_audience: str, algorithm: str): self._jwks_client = jwt.PyJWKClient(f"https://{domain}/.well-known/jwks.json") - def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme)): + def __call__( + self, + auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme), + session: Session = Depends(get_session), + ): token = auth_credentials.credentials try: @@ -24,7 +32,7 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke raise HTTPException(status_code=401, detail=str(e)) try: - payload = jwt.decode( + jwt.decode( token, signing_key, algorithms=self._algorithm, @@ -34,4 +42,11 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke except Exception as e: raise HTTPException(status_code=401, detail=str(e)) - return payload + if session is None: + return None + + # @TODO: get client corresponding to auth + # See: https://github.com/openclimatefix/pv-site-api/issues/90 + client = session.query(ClientSQL).first() + assert client is not None + return client diff --git a/pv_site_api/cache.py b/pv_site_api/cache.py index f9583fa..fec8618 100644 --- a/pv_site_api/cache.py +++ b/pv_site_api/cache.py @@ -39,7 +39,7 @@ def wrapper(*args, **kwargs): # noqa route_variables = kwargs.copy() # drop session and user - for var in ["session", "user"]: + for var in ["session", "user", "auth"]: if var in route_variables: route_variables.pop(var) diff --git a/pv_site_api/enode_auth.py b/pv_site_api/enode_auth.py index c647c54..76aea97 100644 --- a/pv_site_api/enode_auth.py +++ b/pv_site_api/enode_auth.py @@ -12,7 +12,7 @@ def __init__( self._token_url = token_url self._access_token = access_token - def auth_flow(self, request: httpx.Request): + def sync_auth_flow(self, request: httpx.Request): # Add the Authorization header to the request using the current access token request.headers["Authorization"] = f"Bearer {self._access_token}" response = yield request @@ -20,12 +20,27 @@ def auth_flow(self, request: httpx.Request): if response.status_code == 401: # The access token is no longer valid, refresh it token_response = yield self._build_refresh_request() + token_response.read() self._update_access_token(token_response) # Update the request's Authorization header with the new access token request.headers["Authorization"] = f"Bearer {self._access_token}" # Resend the request with the new access token - response = yield request - return response + yield request + + async def async_auth_flow(self, request: httpx.Request): + # Add the Authorization header to the request using the current access token + request.headers["Authorization"] = f"Bearer {self._access_token}" + response = yield request + + if response.status_code == 401: + # The access token is no longer valid, refresh it + token_response = yield self._build_refresh_request() + await token_response.aread() + self._update_access_token(token_response) + # Update the request's Authorization header with the new access token + request.headers["Authorization"] = f"Bearer {self._access_token}" + # Resend the request with the new access token + yield request def _build_refresh_request(self): basic_auth = httpx.BasicAuth(self._client_id, self._client_secret) @@ -35,5 +50,4 @@ def _build_refresh_request(self): return request def _update_access_token(self, response): - response.read() self._access_token = response.json()["access_token"] diff --git a/pv_site_api/main.py b/pv_site_api/main.py index 7379393..acce48d 100644 --- a/pv_site_api/main.py +++ b/pv_site_api/main.py @@ -1,9 +1,9 @@ """Main API Routes""" import os +from typing import Any import httpx import pandas as pd -import sentry_sdk import structlog from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException @@ -20,10 +20,10 @@ import pv_site_api from ._db_helpers import ( - _get_inverters_by_site, does_site_exist, get_forecasts_by_sites, get_generation_by_sites, + get_inverters_for_site, get_sites_by_uuids, site_to_pydantic, ) @@ -82,11 +82,11 @@ def is_fake(): return int(os.environ.get("FAKE", 0)) -sentry_sdk.init( - dsn=os.getenv("SENTRY_DSN"), - environment=os.getenv("ENVIRONMENT", "local"), - traces_sampler=traces_sampler, -) +# sentry_sdk.init( +# dsn=os.getenv("SENTRY_DSN"), +# environment=os.getenv("ENVIRONMENT", "local"), +# traces_sampler=traces_sampler, +# ) app = FastAPI(docs_url="/swagger", redoc_url=None) @@ -133,7 +133,7 @@ def is_fake(): @app.get("/sites", response_model=PVSites) def get_sites( session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### This route returns a list of the user's PV Sites with metadata for each site. @@ -161,7 +161,7 @@ def post_pv_actual( site_uuid: str, pv_actual: MultiplePVActual, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """### This route is used to input actual PV generation. @@ -198,7 +198,7 @@ def put_site_info( site_uuid: str, site_info: PVSiteMetadata, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### This route allows a user to update a site's information. @@ -209,15 +209,8 @@ def put_site_info( print(f"Fake: would update site {site_uuid} with {site_info.dict()}") return - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None - site = ( - session.query(SiteSQL) - .filter_by(client_uuid=client.client_uuid, site_uuid=site_uuid) - .first() + session.query(SiteSQL).filter_by(client_uuid=auth.client_uuid, site_uuid=site_uuid).first() ) if site is None: raise HTTPException(status_code=404, detail="Site not found") @@ -241,7 +234,7 @@ def put_site_info( def post_site_info( site_info: PVSiteMetadata, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### This route allows a user to add a site. @@ -253,13 +246,8 @@ def post_site_info( print("Not doing anything with it (yet!)") return - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None - site = SiteSQL( - client_uuid=client.client_uuid, + client_uuid=auth.client_uuid, client_site_id=site_info.client_site_id, client_site_name=site_info.client_site_name, region=site_info.region, @@ -283,7 +271,7 @@ def post_site_info( def get_pv_actual( site_uuid: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """### This route returns PV readings from a single PV site. @@ -300,7 +288,7 @@ def get_pv_actual( def get_pv_actual_many_sites( site_uuids: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### Get the actual power generation for a list of sites. @@ -320,7 +308,7 @@ def get_pv_actual_many_sites( def get_pv_forecast( site_uuid: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### This route is where you might say the magic happens. @@ -355,12 +343,11 @@ def get_pv_forecast( def get_pv_forecast_many_sites( site_uuids: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: ClientSQL = Depends(auth), ): """ ### Get the forecasts for multiple sites. """ - logger.info(f"Getting forecasts for {site_uuids}") if is_fake(): @@ -383,7 +370,7 @@ def get_pv_forecast_many_sites( def get_pv_estimate_clearsky( site_uuid: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### Gets a estimate of AC production under a clear sky @@ -401,7 +388,7 @@ def get_pv_estimate_clearsky( def get_pv_estimate_clearsky_many_sites( site_uuids: str, session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### Gets a estimate of AC production under a clear sky for multiple sites. @@ -462,8 +449,7 @@ def get_pv_estimate_clearsky_many_sites( @app.get("/enode/link") def get_enode_link( redirect_uri: str, - session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): """ ### Returns a URL from Enode that starts a user's Enode link flow. @@ -471,14 +457,9 @@ def get_enode_link( if is_fake(): return make_fake_enode_link_url() - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None - with httpx.Client(base_url=enode_api_base_url, auth=enode_auth) as httpx_client: data = {"vendorType": "inverter", "redirectUri": redirect_uri} - res = httpx_client.post(f"/users/{client.client_uuid}/link", data=data).json() + res = httpx_client.post(f"/users/{auth.client_uuid}/link", data=data).json() return res["linkUrl"] @@ -486,50 +467,33 @@ def get_enode_link( @app.get("/enode/inverters") async def get_inverters( session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): if is_fake(): return make_fake_inverters() - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None - async with httpx.AsyncClient(base_url=enode_api_base_url, auth=enode_auth) as httpx_client: - headers = {"Enode-User-Id": str(client.client_uuid)} + headers = {"Enode-User-Id": str(auth.client_uuid)} response_json = (await httpx_client.get("/inverters", headers=headers)).json() inverter_ids = [str(inverter_id) for inverter_id in response_json] - return await get_inverters_list( - client.client_uuid, inverter_ids, enode_auth, enode_api_base_url - ) + return await get_inverters_list(auth.client_uuid, inverter_ids, enode_auth, enode_api_base_url) @app.get("/sites/{site_uuid}/inverters") -async def get_inverters_for_site( - site_uuid: str, - session: Session = Depends(get_session), - auth: Auth = Depends(auth), +async def get_inverters_data_for_site( + inverters: list[Any] | None = Depends(get_inverters_for_site), + auth: Any = Depends(auth), ): if is_fake(): return make_fake_inverters() - site_exists = does_site_exist(session, site_uuid) - - if not site_exists: + if inverters is None: raise HTTPException(status_code=404) - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None + inverter_ids = [inverter.client_id for inverter in inverters] - inverter_ids = [inverter.client_id for inverter in _get_inverters_by_site(session, site_uuid)] - - return await get_inverters_list( - client.client_uuid, inverter_ids, enode_auth, enode_api_base_url - ) + return await get_inverters_list(auth.client_uuid, inverter_ids, enode_auth, enode_api_base_url) @app.put("/sites/{site_uuid}/inverters") @@ -537,22 +501,18 @@ def put_inverters_for_site( site_uuid: str, client_ids: list[str], session: Session = Depends(get_session), - auth: Auth = Depends(auth), + auth: Any = Depends(auth), ): + """ + ### Updates a site's inverters with a list of inverter client ids (`client_ids`) + """ if is_fake(): print(f"Successfully changed inverters for {site_uuid}") print("Not doing anything with it (yet!)") return - # @TODO: get client corresponding to auth - # See: https://github.com/openclimatefix/pv-site-api/issues/90 - client = session.query(ClientSQL).first() - assert client is not None - site = ( - session.query(SiteSQL) - .filter_by(client_uuid=client.client_uuid, site_uuid=site_uuid) - .first() + session.query(SiteSQL).filter_by(client_uuid=auth.client_uuid, site_uuid=site_uuid).first() ) if site is None: raise HTTPException(status_code=404, detail="Site not found") diff --git a/tests/conftest.py b/tests/conftest.py index 85b5731..bb55e3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,7 +108,6 @@ def sites(db_session, clients): @pytest.fixture() def inverters(db_session, sites): """Create some fake inverters for site 0""" - inverters = [] num_inverters = 3 inverters = [ InverterSQL(site_uuid=sites[0].site_uuid, client_id=f"id{j+1}") @@ -210,7 +209,7 @@ def forecast_values(db_session, sites): @pytest.fixture() -def client(db_session): +def client(db_session, clients): app.dependency_overrides[get_session] = lambda: db_session - app.dependency_overrides[auth] = lambda: None + app.dependency_overrides[auth] = lambda: clients[0] return TestClient(app) diff --git a/tests/test_auth.py b/tests/test_auth.py index bb72625..5cc98e4 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -10,6 +10,7 @@ from fastapi.testclient import TestClient from pv_site_api.auth import Auth +from pv_site_api.session import get_session # Use symetric algo for simplicity. ALGO = "HS256" @@ -37,10 +38,11 @@ def get_signing_key_from_jwt(self, token): @pytest.fixture -def trivial_client(auth): +def trivial_client(db_session, auth): """A client with only one restricted route.""" app = FastAPI() + app.dependency_overrides[get_session] = lambda: db_session # Add a route that depends on authorization. @app.get("/route", dependencies=[Depends(auth)]) @@ -54,7 +56,7 @@ def _make_header(token): return {"Authorization": f"Bearer {token}"} -def test_auth_happy_path(trivial_client): +def test_auth_happy_path(clients, trivial_client): token = jwt.encode( {"aud": API_AUDIENCE, "iss": f"https://{DOMAIN}/"}, SECRET, diff --git a/tests/test_enode.py b/tests/test_enode.py index 4e79199..7721d93 100644 --- a/tests/test_enode.py +++ b/tests/test_enode.py @@ -7,8 +7,8 @@ def test_get_enode_link_fake(client, fake): params = {"redirect_uri": "https://example.org"} response = client.get("/enode/link", params=params, follow_redirects=False) - assert response.status_code == 307 - assert len(response.headers["location"]) > 0 + assert response.status_code == 200 + assert len(response.json()) > 0 def test_get_enode_link(client, clients, httpx_mock): @@ -26,5 +26,5 @@ def test_get_enode_link(client, clients, httpx_mock): follow_redirects=False, ) - assert response.status_code == 307 - assert response.headers["location"] == test_enode_link_uri + assert response.status_code == 200 + assert response.json() == test_enode_link_uri diff --git a/tests/test_enode_auth.py b/tests/test_enode_auth.py index 93c5569..a13831b 100644 --- a/tests/test_enode_auth.py +++ b/tests/test_enode_auth.py @@ -10,7 +10,7 @@ CLIENT_ID = "ocf" CLIENT_SECRET = "secret" -enode_base_url = "https://enode.com/api" +test_enode_base_url = "https://enode.com/api" @pytest.fixture @@ -20,9 +20,9 @@ def enode_auth(): return enode_auth -def test_enode_auth(enode_auth): - request = httpx.Request("GET", f"{enode_base_url}/inverters") - gen = enode_auth.auth_flow(request) +def test_enode_auth_sync(enode_auth): + request = httpx.Request("GET", f"{test_enode_base_url}/inverters") + gen = enode_auth.sync_auth_flow(request) authenticated_request = next(gen) assert authenticated_request.headers["Authorization"] == "Bearer None" @@ -38,9 +38,9 @@ def test_enode_auth(enode_auth): assert authenticated_request.headers["Authorization"] == f"Bearer {test_access_token}" try: - gen.send(httpx.Response(200, json=["id1"])) - except StopIteration as e: - assert isinstance(e.value, httpx.Response) and e.value.json()[0] == "id1" + next(gen) + except StopIteration: + pass else: # The generator should exit assert False