Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add AI #47

Merged
merged 3 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Configuration settings """

from pydantic_settings import BaseSettings, SettingsConfigDict


Expand Down Expand Up @@ -26,5 +27,9 @@ class Settings(BaseSettings):
# Max number of concurrent compile tasks
max_concurrent_tasks: int = 10

# Groq config
groq_api_key: str
max_llm_tokens: int = 10000


settings = Settings()
1 change: 1 addition & 0 deletions deps/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Code for caching compiled C++ code """

from hashlib import md5

from cachetools import TTLCache
Expand Down
1 change: 1 addition & 0 deletions deps/logs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Create a standard logger """

import logging
from conf import settings

Expand Down
18 changes: 14 additions & 4 deletions deps/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Manage session concurrency """

import uuid
from typing import Annotated

Expand All @@ -8,7 +9,12 @@
from conf import settings

# Hash that stores all known sessions
sessions = TTLCache(maxsize=settings.max_total_sessions, ttl=settings.session_duration)
# In-flight compile/minify request count per session; incremented on entry
# and decremented in a finally block by the endpoints that use it. Entries
# expire after settings.session_duration so abandoned sessions free up.
compile_sessions = TTLCache(
    maxsize=settings.max_total_sessions, ttl=settings.session_duration
)
# LLM tokens consumed per session, compared against settings.max_llm_tokens
# to rate-limit the /ai/generate endpoint.
llm_tokens = TTLCache(
    maxsize=settings.max_total_sessions, ttl=settings.session_duration
)


def get_session_id(
Expand All @@ -20,10 +26,14 @@ def get_session_id(
session_id = uuid.uuid4().hex
response.set_cookie("session_id", session_id)

if session_id not in sessions:
sessions[session_id] = 0
elif sessions[session_id] >= settings.max_sessions_per_user:
if session_id not in compile_sessions:
compile_sessions[session_id] = 0
elif compile_sessions[session_id] >= settings.max_sessions_per_user:
raise HTTPException(403, "Too many sessions.")

if session_id not in llm_tokens:
llm_tokens[session_id] = 0

return session_id


Expand Down
1 change: 1 addition & 0 deletions deps/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Manage startup / shutdown of the application """

import asyncio
from contextlib import asynccontextmanager

Expand Down
1 change: 1 addition & 0 deletions deps/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Tool to repeatedly call a function, copied from fastapi_utils"""

import asyncio
import logging
from asyncio import ensure_future
Expand Down
31 changes: 24 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Leaphy compiler and minifier backend webservice """

import asyncio
import base64
import tempfile
Expand All @@ -7,15 +8,16 @@
import aiofiles
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from groq import Groq
from python_minifier import minify

from conf import settings
from deps.cache import code_cache, get_code_cache_key, library_cache
from deps.logs import logger
from deps.session import Session, sessions
from deps.session import Session, compile_sessions, llm_tokens
from deps.tasks import startup
from deps.utils import check_for_internet
from models import Sketch, Library, PythonProgram
from models import Sketch, Library, PythonProgram, Messages

app = FastAPI(lifespan=startup)
app.add_middleware(
Expand All @@ -25,7 +27,7 @@
allow_methods=["*"],
allow_headers=["*"],
)

client = Groq(api_key=settings.groq_api_key)

# Limit compiler concurrency to prevent overloading the vm
semaphore = asyncio.Semaphore(settings.max_concurrent_tasks)
Expand Down Expand Up @@ -108,7 +110,7 @@ async def _compile_sketch(sketch: Sketch) -> dict[str, str]:
async def compile_cpp(sketch: Sketch, session_id: Session) -> dict[str, str]:
"""Compile code and return the result in HEX format"""
# Make sure there's no more than X compile requests per user
sessions[session_id] += 1
compile_sessions[session_id] += 1

try:
# Check if this code was compiled before
Expand All @@ -124,14 +126,14 @@ async def compile_cpp(sketch: Sketch, session_id: Session) -> dict[str, str]:
code_cache[cache_key] = result
return result
finally:
sessions[session_id] -= 1
compile_sessions[session_id] -= 1


@app.post("/minify/python")
async def minify_python(program: PythonProgram, session_id: Session) -> PythonProgram:
"""Minify a python program"""
# Make sure there's no more than X minify requests per user
sessions[session_id] += 1
compile_sessions[session_id] += 1
try:
# Check if this code was minified before
try:
Expand All @@ -158,4 +160,19 @@ async def minify_python(program: PythonProgram, session_id: Session) -> PythonPr
code_cache[cache_key] = program
return program
finally:
sessions[session_id] -= 1
compile_sessions[session_id] -= 1


@app.post("/ai/generate")
async def generate(messages: Messages, session_id: Session):
    """Generate a chat completion for the given conversation via the Groq API.

    Enforces a per-session token budget: once a session has consumed
    ``settings.max_llm_tokens`` tokens, requests are rejected with HTTP 429
    until the session's entry expires from the TTL cache.

    :param messages: Ordered conversation history to send to the model.
    :param session_id: Caller's session id (cookie-backed dependency).
    :raises HTTPException: 429 when the session's token budget is exhausted.
    :return: The assistant message content produced by the model.
    """
    # The session's token entry may have expired from the TTLCache since
    # get_session_id() initialised it; treat a missing entry as zero usage
    # instead of raising KeyError.
    if llm_tokens.get(session_id, 0) >= settings.max_llm_tokens:
        # HTTPException's second argument already becomes the "detail" field
        # of the JSON response; wrapping it in {"detail": ...} produced a
        # doubly-nested body ({"detail": {"detail": ...}}).
        raise HTTPException(429, "Try again later")

    # NOTE(review): this is a blocking synchronous call inside an async
    # endpoint, which stalls the event loop for the duration of the request;
    # consider groq.AsyncGroq or run_in_executor — TODO confirm.
    response = client.chat.completions.create(
        # model_dump() is the pydantic-v2 replacement for the deprecated
        # .dict() (the project uses pydantic_settings, i.e. pydantic v2).
        messages=[message.model_dump() for message in messages.messages],
        model="llama3-70b-8192",
    )

    # Charge the tokens actually consumed against this session's budget,
    # again tolerating a TTL-expired entry.
    llm_tokens[session_id] = (
        llm_tokens.get(session_id, 0) + response.usage.total_tokens
    )

    return response.choices[0].message.content
14 changes: 14 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" FastAPI models """

from typing import Annotated

from pydantic import BaseModel, Field
Expand All @@ -20,3 +21,16 @@ class PythonProgram(BaseModel):

source_code: bytes # Base64 encoded program
filename: str = ""


class Message(BaseModel):
    """A single chat message in the Groq/OpenAI chat-completions format.

    Instances are serialized (via ``.dict()``) and passed directly to the
    Groq client, so the field names must match the API's message schema.
    """

    # Speaker of the message — presumably "system", "user" or "assistant";
    # not validated here, the Groq API rejects unknown roles. TODO confirm.
    role: str
    # The message text itself.
    content: str


class Messages(BaseModel):
    """Request body for the ``/ai/generate`` endpoint.

    Wraps the conversation history as a list of ``Message`` objects that is
    forwarded verbatim to the Groq chat-completions API.
    """

    # Full conversation to send to the LLM — presumably oldest-first, as
    # expected by chat-completions APIs; TODO confirm with callers.
    messages: list[Message]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ cachetools==5.3.3
python-minifier==2.9.0
pyserial==3.5
httpx==0.27.0
groq==0.9.0
Loading