From 5ecb65685fd6cee2eff0ea4d4f1dee7a6525dbea Mon Sep 17 00:00:00 2001 From: ZachPrice Date: Thu, 17 Oct 2024 23:40:24 -0700 Subject: [PATCH] new auth stuff new functions --- app/auth.py | 19 ++++++++++++-- app/dependencies.py | 42 +++++++++++++++++++++++++++++ app/models.py | 4 +-- app/post.py | 10 ++++--- app/schemas.py | 6 ++++- app/user.py | 29 ++++++++++++++++++-- app/vote.py | 64 +++++++++++++++++++++++++++++++++++++++++---- 7 files changed, 158 insertions(+), 16 deletions(-) create mode 100644 app/dependencies.py diff --git a/app/auth.py b/app/auth.py index 21ba66a..0b08fd6 100644 --- a/app/auth.py +++ b/app/auth.py @@ -1,6 +1,7 @@ # app/api/auth.py -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, status, Depends +from fastapi.security import OAuth2PasswordBearer from datetime import timedelta, datetime from app.models import User from app.schemas import UserCreate, UserResponse @@ -16,7 +17,8 @@ SECRET_KEY = os.getenv("SECRET_KEY", "default_secret_key") ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +ACCESS_TOKEN_EXPIRE_MINUTES = 180 +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def create_access_token(data: dict, expires_delta: timedelta = None): to_encode = data.copy() @@ -28,6 +30,19 @@ def create_access_token(data: dict, expires_delta: timedelta = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt +@auth_router.get("/verify-token") +async def verify_token(token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + wallet_address = payload.get("sub") + if wallet_address is None: + raise HTTPException(status_code=401, detail="Invalid token") + return {"message": "Token is valid", "wallet_address": wallet_address} + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token has expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + @auth_router.post("/signup", response_model=UserResponse) async def signup(user: UserCreate): # Check if the wallet address already exists diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000..77cdec6 --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,42 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from app.models import User +import os +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +SECRET_KEY = os.getenv("SECRET_KEY", "default_secret_key") +ALGORITHM = "HS256" + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO) + +def get_current_user(token: str = Depends(oauth2_scheme)): + logging.info("Decoding JWT token") + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + wallet_address: str = payload.get("sub") + if wallet_address is None: + logging.error("Wallet address not found in token") + raise credentials_exception + except JWTError as e: + logging.error(f"JWT decoding error: {e}") + raise credentials_exception + user = User.objects(wallet_address=wallet_address).first() + if user is None: + logging.error("User not found for wallet address") + raise credentials_exception + logging.info(f"User {user.wallet_address} authenticated successfully") + return user \ No newline at end of file diff --git a/app/models.py b/app/models.py index 8784535..15f6e74 100644 --- a/app/models.py +++ b/app/models.py @@ -31,8 +31,8 @@ class Post(Model): class Vote(Model): __keyspace__ = 'store' id = columns.UUID(primary_key=True, default=uuid4) - post_id = columns.UUID(required=True) - user_id = columns.UUID(required=True) + post_id = columns.UUID(required=True, index=True) # Ensure this field is indexed + user_id = columns.UUID(required=True, index=True) # Ensure this field is indexed vote_value = columns.Integer(required=True) created_at = columns.DateTime() diff --git a/app/post.py b/app/post.py index 169cf0a..cf7bd3b 100644 --- a/app/post.py +++ b/app/post.py @@ -1,10 +1,12 @@ -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Depends from app.models import Post, Vote # Ensure Vote is imported from app.schemas import PostCreate, PostUpdate, PostResponse from uuid import uuid4 from typing import List +from app.models import User from datetime import datetime import logging +from app.dependencies import get_current_user # Configure logging logging.basicConfig(level=logging.INFO) @@ -12,7 +14,7 @@ post_router = APIRouter() -VOTE_THRESHOLD = -100 # Define the vote threshold +VOTE_THRESHOLD = -20 # Define the vote threshold @post_router.get("/", response_model=List[PostResponse]) def read_all_posts(page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100)): @@ -63,10 +65,10 @@ def read_all_posts(page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, return post_responses @post_router.post("/", response_model=PostResponse) -def create_post(post: PostCreate): +def create_post(post: PostCreate, current_user: User = Depends(get_current_user)): new_post = Post( id=uuid4(), - user_id=post.user_id, + user_id=current_user.id, # Use the authenticated user's ID title=post.title, content=post.content, created_at=datetime.now(), diff --git a/app/schemas.py b/app/schemas.py index 2373a8c..63508dc 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -14,6 +14,11 @@ class UserCreate(UserBase): signature: str challenge: str +class UserProfileResponse(BaseModel): + profile_photo_url: str + wallet_address: str + display_name: str + class UserResponse(BaseModel): id: UUID wallet_address: str @@ -57,7 +62,6 @@ class Config: # Vote Schemas class VoteBase(BaseModel): post_id: UUID - user_id: UUID vote_value: int class VoteCreate(VoteBase): diff --git a/app/user.py b/app/user.py index 6ce1019..54c54c5 100644 --- a/app/user.py +++ b/app/user.py @@ -1,16 +1,41 @@ import os from fastapi import APIRouter, HTTPException -from app.models import User -from app.schemas import UserCreate, UserResponse +from app.models import User, Post +from app.schemas import UserCreate, UserResponse, UserProfileResponse from uuid import UUID from datetime import datetime, timedelta from fastapi.encoders import jsonable_encoder user_router = APIRouter() +@user_router.get("/test") +def test_route(): + return {"message": "Test route is working"} + @user_router.get("/{user_id}", response_model=UserResponse) def read_user(user_id: UUID): db_user = User.objects(id=user_id).first() if db_user is None: raise HTTPException(status_code=404, detail="User not found") return db_user + +@user_router.get("/profile-from-post/{post_id}", response_model=UserProfileResponse) +def get_user_profile_from_post(post_id: UUID): + # Retrieve all posts + all_posts = Post.objects.all() + post = next((p for p in all_posts if p.id == post_id), None) + if post is None: + raise HTTPException(status_code=404, detail="Post not found") + + # Retrieve all users + all_users = User.objects.all() + user = next((u for u in all_users if u.id == post.user_id), None) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + + # Return the user's profile photo URL, wallet address, and display name + return { + "profile_photo_url": user.profile_photo_url, + "wallet_address": user.wallet_address, + "display_name": user.display_name + } diff --git a/app/vote.py b/app/vote.py index 70acbd0..979c90d 100644 --- a/app/vote.py +++ b/app/vote.py @@ -1,14 +1,41 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends from app.models import Vote from app.schemas import VoteCreate, VoteResponse from uuid import UUID +from app.dependencies import get_current_user +from app.models import User +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO) vote_router = APIRouter() +# @vote_router.post("/", response_model=VoteResponse) +# def create_vote(vote: VoteCreate, current_user: User = Depends(get_current_user)): +# db_vote = Vote.create(**vote.dict()) +# return db_vote + + @vote_router.post("/", response_model=VoteResponse) -def create_vote(vote: VoteCreate): - db_vote = Vote.create(**vote.dict()) - return db_vote +def toggle_vote(vote: VoteCreate, current_user: User = Depends(get_current_user)): + user = User.objects(wallet_address=current_user.wallet_address).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Retrieve all votes for the post and filter in Python + all_votes = Vote.objects(post_id=vote.post_id).all() + existing_vote = next((v for v in all_votes if v.user_id == user.id), None) + + if existing_vote: + # Determine the adjustment needed when removing the vote + adjustment = -existing_vote.vote_value + existing_vote.delete() + # Return a response indicating the vote was removed and the adjustment + return VoteResponse(id=existing_vote.id, post_id=existing_vote.post_id, vote_value=adjustment) + else: + db_vote = Vote.create(post_id=vote.post_id, vote_value=vote.vote_value, user_id=user.id) + return db_vote @vote_router.get("/{vote_id}", response_model=VoteResponse) def read_vote(vote_id: UUID): @@ -17,8 +44,35 @@ def read_vote(vote_id: UUID): raise HTTPException(status_code=404, detail="Vote not found") return db_vote +@vote_router.get("/post/{post_id}/votes", response_model=int) +def get_post_vote_value(post_id: UUID): + votes = Vote.objects(post_id=post_id).all() + total_vote_value = sum(vote.vote_value for vote in votes) + return total_vote_value + +@vote_router.get("/post/{post_id}/user-vote", response_model=dict) +def get_user_vote_on_post(post_id: UUID, current_user: User = Depends(get_current_user)): + logging.info(f"Fetching user with wallet address: {current_user.wallet_address}") + user = User.objects(wallet_address=current_user.wallet_address).first() + if not user: + logging.error("User not found") + raise HTTPException(status_code=404, detail="User not found") + + logging.info(f"User found: {user.id}, fetching all votes for post: {post_id}") + all_votes = Vote.objects(post_id=post_id).all() + + # Filter votes in Python + user_vote = next((vote for vote in all_votes if vote.user_id == user.id), None) + + if user_vote is None: + logging.info("No vote found for user on this post") + return {"vote_value": 0} + + logging.info(f"Vote found: {user_vote.vote_value}") + return {"vote_value": user_vote.vote_value} + @vote_router.delete("/{vote_id}", response_model=VoteResponse) -def delete_vote(vote_id: UUID): +def delete_vote(vote_id: UUID, current_user: User = Depends(get_current_user)): db_vote = Vote.objects(id=vote_id).first() if db_vote is None: raise HTTPException(status_code=404, detail="Vote not found")