Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
Pinewood One Backend - Main Application Entry Point
"""
import os
from flask import Flask
from flask import Flask, jsonify
from flask_cors import CORS
from config import Config
from db.init import init_db
from extensions import limiter

# Import blueprints
from auth.routes import auth_bp
from api.routes import api_bp
from schoology.routes import oauth_bp as schoology_oauth_bp, schoology_api_bp
from mobile.routes import mobile_bp


def create_app():
Expand All @@ -27,6 +29,10 @@ def create_app():

# CORS configuration
CORS(app, origins=[Config.FRONTEND_URL, "http://localhost:3112"], supports_credentials=True)

# Rate limiter
app.config["RATELIMIT_STORAGE_URI"] = Config.RATELIMIT_STORAGE_URI
limiter.init_app(app)

# Validate configuration
Config.validate()
Expand All @@ -39,6 +45,11 @@ def create_app():
app.register_blueprint(api_bp)
app.register_blueprint(schoology_oauth_bp)
app.register_blueprint(schoology_api_bp)
app.register_blueprint(mobile_bp)

@app.errorhandler(429)
def handle_rate_limit(_error):
return jsonify({"error": "rate_limited"}), 429

return app

Expand Down
151 changes: 97 additions & 54 deletions auth/jwt_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""
JWT utilities for Convex authentication using RS256
JWT utilities for backend-issued RS256 tokens.
"""
import os
import json
import base64
import secrets
from datetime import datetime, timedelta, timezone

import jwt
from datetime import datetime, timedelta
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend

from config import Config

# JWT configuration
JWT_ALGORITHM = "RS256"
JWT_EXPIRATION_HOURS = 24
JWT_ISSUER = os.environ.get("JWT_ISSUER", Config.BACKEND_URL)
JWT_AUDIENCE = "convex"
JWT_KEY_ID = "pinewood-one-key-1"
JWT_CONVEX_AUDIENCE = "convex"
JWT_MOBILE_AUDIENCE = "mobile_api"
JWT_DEFAULT_CONVEX_EXPIRATION_HOURS = 24

# RSA key paths
PRIVATE_KEY_PATH = os.path.join(os.path.dirname(__file__), "..", "keys", "private.pem")
Expand All @@ -30,57 +33,52 @@ def _ensure_keys_exist():
os.makedirs(keys_dir)

if not os.path.exists(PRIVATE_KEY_PATH) or not os.path.exists(PUBLIC_KEY_PATH):
# Generate new RSA key pair
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
backend=default_backend(),
)
public_key = private_key.public_key()

# Save private key
with open(PRIVATE_KEY_PATH, "wb") as f:
f.write(private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
))
f.write(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)

# Save public key
with open(PUBLIC_KEY_PATH, "wb") as f:
f.write(public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
))
f.write(
public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
)

print("Generated new RSA key pair for JWT signing")


def _load_private_key():
"""Load the RSA private key."""
_ensure_keys_exist()
with open(PRIVATE_KEY_PATH, "rb") as f:
return serialization.load_pem_private_key(f.read(), password=None)


def _load_public_key():
"""Load the RSA public key."""
_ensure_keys_exist()
with open(PUBLIC_KEY_PATH, "rb") as f:
return serialization.load_pem_public_key(f.read())


def _int_to_base64url(n: int) -> str:
"""Convert an integer to base64url encoding."""
byte_length = (n.bit_length() + 7) // 8
return base64.urlsafe_b64encode(n.to_bytes(byte_length, "big")).rstrip(b"=").decode("ascii")


def get_jwks() -> dict:
"""
Get the JSON Web Key Set for the public key.
This is used by Convex to verify JWT signatures.
"""
"""Get JSON Web Key Set for public key verification."""
public_key = _load_public_key()
public_numbers = public_key.public_numbers()

Expand All @@ -98,60 +96,105 @@ def get_jwks() -> dict:
}


def create_convex_token(user_id: int, email: str, name: str) -> str:
"""
Create a JWT token for Convex authentication.

Args:
user_id: The internal user ID
email: User's email address
name: User's display name

Returns:
JWT token string
"""
def create_token(
user_id: int,
email: str,
name: str,
audience: str,
expires_in_seconds: int | None = None,
extra_claims: dict | None = None,
) -> str:
"""Create a signed JWT for the given audience."""
private_key = _load_private_key()

now = datetime.utcnow()
if expires_in_seconds is None:
if audience == JWT_CONVEX_AUDIENCE:
expires_in_seconds = JWT_DEFAULT_CONVEX_EXPIRATION_HOURS * 3600
else:
expires_in_seconds = Config.MOBILE_ACCESS_TOKEN_TTL_SECONDS

now = datetime.now(timezone.utc)
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function uses datetime.now(timezone.utc) while other parts of the codebase use mobile_db.utcnow() which returns datetime.now(timezone.utc). Consider using the centralized mobile_db.utcnow() helper for consistency, or create a shared utility function to avoid duplicating this pattern.

Copilot uses AI. Check for mistakes.
payload = {
"sub": str(user_id),
"email": email,
"name": name,
"iss": JWT_ISSUER,
"aud": JWT_AUDIENCE,
"aud": audience,
"iat": now,
"exp": now + timedelta(hours=JWT_EXPIRATION_HOURS)
"exp": now + timedelta(seconds=expires_in_seconds),
}

if extra_claims:
payload.update(extra_claims)

headers = {
"kid": JWT_KEY_ID,
"typ": "JWT",
"alg": JWT_ALGORITHM
"alg": JWT_ALGORITHM,
}

return jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM, headers=headers)


def verify_convex_token(token: str) -> dict | None:
"""
Verify and decode a JWT token.

Args:
token: JWT token string

Returns:
Decoded payload dict or None if invalid
"""
def create_convex_token(
user_id: int,
email: str,
name: str,
expires_in_seconds: int | None = None,
) -> str:
"""Backwards-compatible helper for Convex JWT creation."""
return create_token(
user_id=user_id,
email=email,
name=name,
audience=JWT_CONVEX_AUDIENCE,
expires_in_seconds=expires_in_seconds,
)


def create_mobile_access_token(
user_id: int,
email: str,
name: str,
device_id: str,
expires_in_seconds: int | None = None,
) -> str:
"""Create a mobile API access token."""
return create_token(
user_id=user_id,
email=email,
name=name,
audience=JWT_MOBILE_AUDIENCE,
expires_in_seconds=expires_in_seconds or Config.MOBILE_ACCESS_TOKEN_TTL_SECONDS,
extra_claims={
"device_id": device_id,
"jti": secrets.token_hex(16),
},
)


def verify_token(token: str, audience: str) -> dict | None:
"""Verify and decode a JWT for the expected audience."""
try:
public_key = _load_public_key()
return jwt.decode(
token,
public_key,
algorithms=[JWT_ALGORITHM],
audience=JWT_AUDIENCE,
issuer=JWT_ISSUER
audience=audience,
issuer=JWT_ISSUER,
)
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None


def verify_convex_token(token: str) -> dict | None:
"""Backwards-compatible helper for Convex JWT verification."""
return verify_token(token, audience=JWT_CONVEX_AUDIENCE)


def verify_mobile_access_token(token: str) -> dict | None:
"""Verify a mobile access token."""
return verify_token(token, audience=JWT_MOBILE_AUDIENCE)
50 changes: 50 additions & 0 deletions auth/mobile_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Bearer-token middleware for mobile API routes.
"""
from functools import wraps

from flask import jsonify, g, request

from auth.jwt_utils import verify_mobile_access_token
from db.users import get_user_by_id


def _extract_bearer_token() -> str | None:
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[len("Bearer ") :].strip()
return token or None


def mobile_auth_required(func):
"""Require a valid mobile access token and attach user context."""

@wraps(func)
def wrapper(*args, **kwargs):
token = _extract_bearer_token()
if not token:
return jsonify({"error": "authentication_required"}), 401

payload = verify_mobile_access_token(token)
if not payload:
return jsonify({"error": "invalid_token"}), 401

try:
user_id = int(payload.get("sub", ""))
except (TypeError, ValueError):
return jsonify({"error": "invalid_token"}), 401

user_data = get_user_by_id(user_id)
if not user_data:
return jsonify({"error": "invalid_token"}), 401

g.mobile_user = {
"id": user_data["id"],
"device_id": payload.get("device_id"),
}
kwargs["user"] = user_data
kwargs["token_payload"] = payload
return func(*args, **kwargs)

return wrapper
31 changes: 31 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,35 @@ class Config:
# Convex configuration
CONVEX_URL = os.environ.get("CONVEX_URL", "https://hearty-lemur-131.convex.cloud")

# Mobile auth/token configuration
MOBILE_ACCESS_TOKEN_TTL_SECONDS = int(os.environ.get("MOBILE_ACCESS_TOKEN_TTL_SECONDS", "900"))
MOBILE_REFRESH_TOKEN_TTL_DAYS = int(os.environ.get("MOBILE_REFRESH_TOKEN_TTL_DAYS", "30"))
MOBILE_AUTH_CODE_TTL_SECONDS = int(os.environ.get("MOBILE_AUTH_CODE_TTL_SECONDS", "120"))
MOBILE_WEB_TICKET_TTL_SECONDS = int(os.environ.get("MOBILE_WEB_TICKET_TTL_SECONDS", "60"))
MOBILE_STATE_MAX_AGE_SECONDS = int(os.environ.get("MOBILE_STATE_MAX_AGE_SECONDS", "300"))
MOBILE_SCHOOLOGY_REQUEST_TTL_SECONDS = int(
os.environ.get("MOBILE_SCHOOLOGY_REQUEST_TTL_SECONDS", "300")
)
MOBILE_TOKEN_HASH_SECRET = os.environ.get("MOBILE_TOKEN_HASH_SECRET")
MOBILE_ALLOWED_REDIRECT_URIS = [
value.strip()
for value in os.environ.get(
"MOBILE_ALLOWED_REDIRECT_URIS",
"pinewoodone://auth/callback",
).split(",")
if value.strip()
]

# Rate limiting
RATELIMIT_STORAGE_URI = os.environ.get("RATELIMIT_STORAGE_URI", "memory://")

# Mobile banner metadata
BANNER_UPCOMING_IMAGE_URL = os.environ.get("BANNER_UPCOMING_IMAGE_URL")
BANNER_UPCOMING_VERSION = os.environ.get("BANNER_UPCOMING_VERSION", "v1")
BANNER_UPCOMING_CACHE_TTL_SECONDS = int(
os.environ.get("BANNER_UPCOMING_CACHE_TTL_SECONDS", "86400")
)

@classmethod
def validate(cls):
"""Validate configuration and print status"""
Expand All @@ -52,3 +81,5 @@ def validate(cls):
print(f" Consumer Key: {cls.SCHOOLOGY_CONSUMER_KEY[:20]}...")
print(f" Domain: {cls.SCHOOLOGY_DOMAIN}")

if os.environ.get("FLASK_ENV") == "production" and not cls.MOBILE_TOKEN_HASH_SECRET:
raise ValueError("MOBILE_TOKEN_HASH_SECRET is required in production")
Loading
Loading