diff --git a/conf.py b/conf.py index 8c1275a..6b2fe98 100644 --- a/conf.py +++ b/conf.py @@ -1,4 +1,5 @@ """ Configuration settings """ + from pydantic_settings import BaseSettings, SettingsConfigDict @@ -26,5 +27,9 @@ class Settings(BaseSettings): # Max number of concurrent compile tasks max_concurrent_tasks: int = 10 + # Groq config + groq_api_key: str + max_llm_tokens: int = 10000 + settings = Settings() diff --git a/deps/cache.py b/deps/cache.py index 1b8bead..d96bcec 100644 --- a/deps/cache.py +++ b/deps/cache.py @@ -1,4 +1,5 @@ """ Code for caching compiled C++ code """ + from hashlib import md5 from cachetools import TTLCache diff --git a/deps/logs.py b/deps/logs.py index 1bcd252..d195808 100644 --- a/deps/logs.py +++ b/deps/logs.py @@ -1,4 +1,5 @@ """ Create a standard logger """ + import logging from conf import settings diff --git a/deps/session.py b/deps/session.py index 7e9cd94..7ca5228 100644 --- a/deps/session.py +++ b/deps/session.py @@ -1,4 +1,5 @@ """ Manage session concurrency """ + import uuid from typing import Annotated @@ -8,7 +9,12 @@ from conf import settings # Hash that stores all known sessions -sessions = TTLCache(maxsize=settings.max_total_sessions, ttl=settings.session_duration) +compile_sessions = TTLCache( + maxsize=settings.max_total_sessions, ttl=settings.session_duration +) +llm_tokens = TTLCache( + maxsize=settings.max_total_sessions, ttl=settings.session_duration +) def get_session_id( @@ -20,10 +26,14 @@ def get_session_id( session_id = uuid.uuid4().hex response.set_cookie("session_id", session_id) - if session_id not in sessions: - sessions[session_id] = 0 - elif sessions[session_id] >= settings.max_sessions_per_user: + if session_id not in compile_sessions: + compile_sessions[session_id] = 0 + elif compile_sessions[session_id] >= settings.max_sessions_per_user: raise HTTPException(403, "Too many sessions.") + + if session_id not in llm_tokens: + llm_tokens[session_id] = 0 + return session_id diff --git a/deps/tasks.py 
b/deps/tasks.py index c8106e5..aa17fb5 100644 --- a/deps/tasks.py +++ b/deps/tasks.py @@ -1,4 +1,5 @@ """ Manage startup / shutdown of the application """ + import asyncio from contextlib import asynccontextmanager diff --git a/deps/utils.py b/deps/utils.py index 84c8cf5..1f084a5 100644 --- a/deps/utils.py +++ b/deps/utils.py @@ -1,4 +1,5 @@ """ Tool to repeatedly call a function, copied from fastapi_utils""" + import asyncio import logging from asyncio import ensure_future diff --git a/main.py b/main.py index c57fe6c..f97d099 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ """ Leaphy compiler and minifier backend webservice """ + import asyncio import base64 import tempfile @@ -7,15 +8,16 @@ import aiofiles from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware +from groq import Groq from python_minifier import minify from conf import settings from deps.cache import code_cache, get_code_cache_key, library_cache from deps.logs import logger -from deps.session import Session, sessions +from deps.session import Session, compile_sessions, llm_tokens from deps.tasks import startup from deps.utils import check_for_internet -from models import Sketch, Library, PythonProgram +from models import Sketch, Library, PythonProgram, Messages app = FastAPI(lifespan=startup) app.add_middleware( @@ -25,7 +27,7 @@ allow_methods=["*"], allow_headers=["*"], ) - +client = Groq(api_key=settings.groq_api_key) # Limit compiler concurrency to prevent overloading the vm semaphore = asyncio.Semaphore(settings.max_concurrent_tasks) @@ -108,7 +110,7 @@ async def _compile_sketch(sketch: Sketch) -> dict[str, str]: async def compile_cpp(sketch: Sketch, session_id: Session) -> dict[str, str]: """Compile code and return the result in HEX format""" # Make sure there's no more than X compile requests per user - sessions[session_id] += 1 + compile_sessions[session_id] += 1 try: # Check if this code was compiled before @@ -124,14 +126,14 @@ async def 
compile_cpp(sketch: Sketch, session_id: Session) -> dict[str, str]: code_cache[cache_key] = result return result finally: - sessions[session_id] -= 1 + compile_sessions[session_id] -= 1 @app.post("/minify/python") async def minify_python(program: PythonProgram, session_id: Session) -> PythonProgram: """Minify a python program""" # Make sure there's no more than X minify requests per user - sessions[session_id] += 1 + compile_sessions[session_id] += 1 try: # Check if this code was minified before try: @@ -158,4 +160,19 @@ async def minify_python(program: PythonProgram, session_id: Session) -> PythonPr code_cache[cache_key] = program return program finally: - sessions[session_id] -= 1 + compile_sessions[session_id] -= 1 + + +@app.post("/ai/generate") +async def generate(messages: Messages, session_id: Session): + """Generate message""" + if llm_tokens[session_id] >= settings.max_llm_tokens: + raise HTTPException(429, "Try again later") + + response = client.chat.completions.create( + messages=list(map(lambda e: e.model_dump(), messages.messages)), + model="llama3-70b-8192", + ) + llm_tokens[session_id] += response.usage.total_tokens + + return response.choices[0].message.content diff --git a/models.py b/models.py index 1f6347d..a9fd1c7 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,5 @@ """ FastAPI models """ + from typing import Annotated from pydantic import BaseModel, Field @@ -20,3 +21,16 @@ class PythonProgram(BaseModel): source_code: bytes # Base64 encoded program filename: str = "" + + +class Message(BaseModel): + """Model representing a message""" + + role: str + content: str + + +class Messages(BaseModel): + """Model representing a collection of messages""" + + messages: list[Message] diff --git a/requirements.txt b/requirements.txt index d243b74..cd7a027 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ cachetools==5.3.3 python-minifier==2.9.0 pyserial==3.5 httpx==0.27.0 +groq==0.9.0