Skip to content

Add endpoints for getting exams from the exam bank #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
60 changes: 60 additions & 0 deletions src/alembic/versions/2f1b67c68ba5_add_exam_bank_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""add exam bank tables

Revision ID: 2f1b67c68ba5
Revises: 3f19883760ae
Create Date: 2025-01-03 00:24:44.608869

"""
from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "2f1b67c68ba5"
down_revision: str | None = "3f19883760ae"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.create_table(
"professor",
sa.Column("professor_id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("name", sa.String(128), nullable=False),
sa.Column("info_url", sa.String(128), nullable=False),
sa.Column("computing_id", sa.String(32), sa.ForeignKey("user_session.computing_id"), nullable=True),
)

op.create_table(
"course",
sa.Column("course_id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("course_faculty", sa.String(12), nullable=False),
sa.Column("course_number", sa.String(12), nullable=False),
sa.Column("course_name", sa.String(96), nullable=False),
)

op.create_table(
"exam_metadata",
sa.Column("exam_id", sa.Integer, primary_key=True),
sa.Column("upload_date", sa.DateTime, nullable=False),
sa.Column("exam_pdf_size", sa.Integer, nullable=False),

sa.Column("author_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=False),
sa.Column("author_confirmed", sa.Boolean, nullable=False),
sa.Column("author_permission", sa.Boolean, nullable=False),

sa.Column("kind", sa.String(24), nullable=False),
sa.Column("course_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=True),
sa.Column("title", sa.String(96), nullable=True),
sa.Column("description", sa.Text, nullable=True),

sa.Column("date_string", sa.String(10), nullable=False),
)


def downgrade() -> None:
op.drop_table("exam_metadata")
op.drop_table("professor")
op.drop_table("course")
25 changes: 25 additions & 0 deletions src/alembic/versions/3f19883760ae_add_session_type_to_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""add session_type to auth

Revision ID: 3f19883760ae
Revises: 2a6ea95342dc
Create Date: 2025-01-03 00:16:50.579541

"""
from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "3f19883760ae"
down_revision: str | None = "2a6ea95342dc"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.add_column("user_session", sa.Column("session_type", sa.String(48), nullable=False))

def downgrade() -> None:
op.drop_column("user_session", "session_type")
52 changes: 36 additions & 16 deletions src/auth/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

_logger = logging.getLogger(__name__)

async def create_user_session(db_session: AsyncSession, session_id: str, computing_id: str):
async def create_user_session(
db_session: AsyncSession,
session_id: str,
computing_id: str,
session_type: str,
) -> None:
"""
Updates the past user session if one exists, so no duplicate sessions can ever occur.

Expand Down Expand Up @@ -46,50 +51,67 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
existing_user.last_logged_in = datetime.now()
else:
db_session.add(UserSession(
session_id=session_id,
computing_id=computing_id,
issue_time=datetime.now(),
session_id=session_id,
session_type=session_type,
))


async def remove_user_session(db_session: AsyncSession, session_id: str) -> dict:
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
user_session = await db_session.scalars(query)
user_session = await db_session.scalars(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
await db_session.delete(user_session.first())


async def get_computing_id(db_session: AsyncSession, session_id: str) -> str | None:
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
existing_user_session = (await db_session.scalars(query)).first()
existing_user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
return existing_user_session.computing_id if existing_user_session else None


async def get_session_type(db_session: AsyncSession, session_id: str) -> str | None:
existing_user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
return existing_user_session.session_type if existing_user_session else None


# remove all out of date user sessions
async def task_clean_expired_user_sessions(db_session: AsyncSession):
one_day_ago = datetime.now() - timedelta(days=0.5)

query = sqlalchemy.delete(UserSession).where(UserSession.issue_time < one_day_ago)
await db_session.execute(query)
await db_session.execute(
sqlalchemy
.delete(UserSession)
.where(UserSession.issue_time < one_day_ago)
)
await db_session.commit()


# get the site user given a session ID; returns None when session is invalid
async def get_site_user(db_session: AsyncSession, session_id: str) -> None | SiteUserData:
query = (
user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
user_session = await db_session.scalar(query)
if user_session is None:
return None

query = (
user = await db_session.scalar(
sqlalchemy
.select(SiteUser)
.where(SiteUser.computing_id == user_session.computing_id)
)
user = await db_session.scalar(query)
if user is None:
return None

Expand All @@ -116,21 +138,19 @@ async def update_site_user(
session_id: str,
profile_picture_url: str
) -> bool:
query = (
user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
user_session = await db_session.scalar(query)
if user_session is None:
return False

query = (
await db_session.execute(
sqlalchemy
.update(SiteUser)
.where(SiteUser.computing_id == user_session.computing_id)
.values(profile_picture_url = profile_picture_url)
)
await db_session.execute(query)

return True
4 changes: 3 additions & 1 deletion src/auth/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import Column, DateTime, ForeignKey, String, Text

from constants import COMPUTING_ID_LEN, SESSION_ID_LEN
from constants import COMPUTING_ID_LEN, SESSION_ID_LEN, SESSION_TYPE_LEN
from database import Base


Expand All @@ -23,6 +23,8 @@ class UserSession(Base):
String(SESSION_ID_LEN), nullable=False, unique=True
) # the space needed to store 256 bytes in base64

# whether a user is faculty, csss-member, student, or just "sfu"
session_type = Column(String(SESSION_TYPE_LEN), nullable=False)

class SiteUser(Base):
# user is a reserved word in postgres
Expand Down
20 changes: 20 additions & 0 deletions src/auth/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
from dataclasses import dataclass


class SessionType:
# see: https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications/
# for more info on the kinds of members
FACULTY = "faculty"
# TODO: what will happen to the maillists for authentication; are groups part of this?
CSSS_MEMBER = "csss member" # !cs-students maillist
STUDENT = "student"
ALUMNI = "alumni"
SFU = "sfu"

@staticmethod
def valid_session_type_list():
# values taken from https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications.html
return [
"faculty",
"student",
"alumni",
"sfu"
]

@dataclass
class SiteUserData:
computing_id: str
Expand Down
38 changes: 32 additions & 6 deletions src/auth/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import database
from auth import crud
from auth.types import SessionType
from constants import FRONTEND_ROOT_URL

_logger = logging.getLogger(__name__)
Expand All @@ -31,7 +32,6 @@ def generate_session_id_b64(num_bytes: int) -> str:
tags=["authentication"],
)


# NOTE: logging in a second time invaldiates the last session_id
@router.get(
"/login",
Expand All @@ -47,17 +47,43 @@ async def login_user(
# verify the ticket is valid
service = urllib.parse.quote(f"{FRONTEND_ROOT_URL}/api/auth/login?redirect_path={redirect_path}&redirect_fragment={redirect_fragment}")
service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={ticket}"
cas_response = xmltodict.parse(requests.get(service_validate_url).text)
cas_response_text = requests.get(service_validate_url).text
cas_response = xmltodict.parse(cas_response_text)

print("CAS RESPONSE ::")
print(cas_response_text)

if "cas:authenticationFailure" in cas_response["cas:serviceResponse"]:
_logger.info(f"User failed to login, with response {cas_response}")
raise HTTPException(status_code=401, detail="authentication error, ticket likely invalid")

else:
elif "cas:authenticationSuccess" in cas_response["cas:serviceResponse"]:
session_id = generate_session_id_b64(256)
computing_id = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:user"]

await crud.create_user_session(db_session, session_id, computing_id)
# NOTE: it is the frontend's job to pass the correct authentication reuqest to CAS, otherwise we
# will only be able to give a user the "sfu" session_type (least privileged)
if "cas:maillist" in cas_response["cas:serviceResponse"]:
# maillist
# TODO: (ASK SFU IT) can alumni be in the cmpt-students maillist?
if cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:maillist"] == "cmpt-students":
session_type = SessionType.CSSS_MEMBER
else:
raise HTTPException(status_code=500, details="malformed cas:maillist authentication response; this is an SFU CAS error")
elif "cas:authtype" in cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]:
# sfu, alumni, faculty, student
session_type = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:authtype"]
if session_type not in SessionType.valid_session_type_list():
raise HTTPException(status_code=500, detail=f"unexpected session type from SFU CAS of {session_type}")

if session_type == "alumni":
if "@" not in computing_id:
raise HTTPException(status_code=500, detail=f"invalid alumni computing_id response from CAS AUTH with value {session_type}")
computing_id = computing_id.split("@")[0]
else:
raise HTTPException(status_code=500, detail="malformed unknown authentication response; this is an SFU CAS error")

await crud.create_user_session(db_session, session_id, computing_id, session_type)
await db_session.commit()

# clean old sessions after sending the response
Expand All @@ -69,6 +95,8 @@ async def login_user(
) # this overwrites any past, possibly invalid, session_id
return response

else:
raise HTTPException(status_code=500, detail="malformed authentication response; this is an SFU CAS error")

@router.get(
"/logout",
Expand All @@ -91,7 +119,6 @@ async def logout_user(
response.delete_cookie(key="session_id")
return response


@router.get(
"/user",
description="Get info about the current user. Only accessible by that user",
Expand All @@ -113,7 +140,6 @@ async def get_user(

return JSONResponse(user_info.serializable_dict())


@router.patch(
"/user",
description="Update information for the currently logged in user. Only accessible by that user",
Expand Down
20 changes: 20 additions & 0 deletions src/auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from fastapi import HTTPException, Request

import auth.crud
import database


async def logged_in_or_raise(
request: Request,
db_session: database.DBSession
) -> tuple[str, str]:
"""gets the user's computing_id, or raises an exception if the current request is not logged in"""
session_id = request.cookies.get("session_id", None)
if session_id is None:
raise HTTPException(status_code=401)

session_computing_id = await auth.crud.get_computing_id(db_session, session_id)
if session_computing_id is None:
raise HTTPException(status_code=401)

return session_id, session_computing_id
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
COMPUTING_ID_LEN = 32
COMPUTING_ID_MAX = 8

# depends how large SFU maillists can be
SESSION_TYPE_LEN = 48

# see https://support.discord.com/hc/en-us/articles/4407571667351-How-to-Find-User-IDs-for-Law-Enforcement#:~:text=Each%20Discord%20user%20is%20assigned,user%20and%20cannot%20be%20changed.
# NOTE: the length got updated to 19 in july 2024. See https://www.reddit.com/r/discordapp/comments/ucrp1r/only_3_months_until_discord_ids_hit_19_digits/
# I set us to 32 just in case...
Expand Down
Loading
Loading