Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 123 additions & 66 deletions app/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,112 @@
AI functionality for Sugar-AI, including RAG and LLM components.
"""
import os
import torch
from transformers import pipeline
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import logging
import typing
from typing import Optional, List
import app.prompts as prompts
from app.config import settings
import logging

logger = logging.getLogger("sugar-ai")

# Lazy imports for heavy libraries.
# These module-level names start out as None and are populated by
# _lazy_load_deps() on first use, so importing this module stays cheap
# (e.g. in DEV_MODE, unit tests, or tooling that never loads a model).
torch = None                  # torch (ML runtime; GPU/dtype decisions)
transformers = None           # transformers (pipelines, tokenizers, models)
FAISS = None                  # langchain_community.vectorstores.FAISS
HuggingFaceEmbeddings = None  # langchain_huggingface.HuggingFaceEmbeddings
PyMuPDFLoader = None          # langchain_community.document_loaders.PyMuPDFLoader
TextLoader = None             # langchain_community.document_loaders.TextLoader
RunnablePassthrough = None    # langchain_core.runnables.RunnablePassthrough
ChatPromptTemplate = None     # langchain_core.prompts.ChatPromptTemplate

def _lazy_load_deps():
    """Lazily load heavy AI dependencies only when needed.

    Populates the module-level globals (torch, transformers, langchain
    classes) the first time it is called; later calls are cheap no-ops.
    Because each name below is declared ``global``, the ``import`` /
    ``from ... import`` statements bind the module-level name directly —
    no temporary alias is needed.
    """
    global torch, transformers, FAISS, HuggingFaceEmbeddings, PyMuPDFLoader, TextLoader, RunnablePassthrough, ChatPromptTemplate

    if torch is None:
        import torch
    if transformers is None:
        import transformers
    if FAISS is None:
        from langchain_community.vectorstores import FAISS
    if HuggingFaceEmbeddings is None:
        from langchain_huggingface import HuggingFaceEmbeddings
    if PyMuPDFLoader is None:
        from langchain_community.document_loaders import PyMuPDFLoader
    if TextLoader is None:
        from langchain_community.document_loaders import TextLoader
    if RunnablePassthrough is None:
        from langchain_core.runnables import RunnablePassthrough
    if ChatPromptTemplate is None:
        from langchain_core.prompts import ChatPromptTemplate

class ModelManager:
    """Manages model loading and caching to optimize memory usage.

    Pipelines and tokenizers are cached per model name in class-level
    dicts, so repeated requests for the same model reuse one instance
    instead of loading it again.
    """
    # model_name -> cached text-generation pipeline
    _models = {}
    # model_name -> tokenizer belonging to the cached pipeline
    _tokenizers = {}

    @classmethod
    def get_model(cls, model_name: str, quantize: bool = True):
        """Load and return (pipeline, tokenizer) for ``model_name``, cached.

        Args:
            model_name: Hugging Face model identifier to load.
            quantize: Load the model 4-bit quantized when a CUDA GPU is
                available; ignored on CPU.

        Returns:
            ``(pipeline, tokenizer)`` tuple, or ``(None, None)`` in DEV_MODE
            where no real model is loaded.
        """
        if settings.DEV_MODE:
            logger.info("DEV_MODE active: skipping actual model load in ModelManager")
            return None, None

        _lazy_load_deps()

        # Serve from cache when this model was already loaded.
        if model_name in cls._models:
            return cls._models[model_name], cls._tokenizers[model_name]

        # Lazy %-style args so the message is only formatted if emitted.
        logger.info("Loading model: %s", model_name)

        # Device 0 = first CUDA GPU, -1 = CPU (transformers pipeline convention).
        device = 0 if torch.cuda.is_available() else -1

        if quantize and device == 0:
            # 4-bit NF4 quantization (GPU only). `transformers` is already
            # imported by _lazy_load_deps(), so use it directly instead of a
            # redundant local `from transformers import BitsAndBytesConfig`.
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
            model_obj = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            pipe = transformers.pipeline(
                "text-generation",
                model=model_obj,
                tokenizer=tokenizer,
                max_new_tokens=1024,
                truncation=True,
            )
        else:
            # Full-precision fallback: fp16 on GPU, fp32 on CPU. dtype is only
            # needed on this path, so it is computed here rather than eagerly.
            dtype = torch.float16 if device == 0 else torch.float32
            pipe = transformers.pipeline(
                "text-generation",
                model=model_name,
                max_new_tokens=1024,
                truncation=True,
                torch_dtype=dtype,
                device=device,
            )
            tokenizer = pipe.tokenizer

        cls._models[model_name] = pipe
        cls._tokenizers[model_name] = tokenizer
        return pipe, tokenizer

def format_docs(docs):
    """Concatenate the ``page_content`` of each document, blank-line separated."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
Expand Down Expand Up @@ -50,76 +142,41 @@ class RAGAgent:
"""Retrieval-Augmented Generation agent for Sugar-AI"""

def __init__(self, model: Optional[str] = None, quantize: bool = True):
# 1) Determine model name with clear precedence:
# explicit argument > DEV_MODEL_NAME (if DEV_MODE) > PROD_MODEL_NAME > DEFAULT_MODEL
self.model = None
self.simplify_model = None
self.retriever: Optional[typing.Any] = None
self.quantize = quantize

# Determine model name with clear precedence
if model:
self.model_name = model
logger.info("Using explicit model argument: %s", self.model_name)
else:
if getattr(settings, "DEV_MODE", False):
# prefer DEV_MODEL_NAME, then fallback to DEFAULT_MODEL
if settings.DEV_MODE:
self.model_name = getattr(settings, "DEV_MODEL_NAME", settings.DEFAULT_MODEL)
logger.info("DEV_MODE active: using lightweight model %s", self.model_name)
logger.info("DEV_MODE active: using placeholder for agent model")
else:
# production: prefer PROD_MODEL_NAME, else DEFAULT_MODEL
self.model_name = getattr(settings, "PROD_MODEL_NAME", settings.DEFAULT_MODEL)
logger.info("Using production model %s", self.model_name)

# 2) Compute quantization/device choices. Keep quantization off in DEV_MODE by default.
self.use_quant = quantize and torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False)
device = 0 if torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False) else -1
dtype = torch.float16 if device == 0 else torch.float32

if self.use_quant:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(self.model_name)
model_obj = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=bnb_config,
torch_dtype=torch.float16,
device_map="auto"
)
self.model = pipeline(
"text-generation",
model=model_obj,
tokenizer=tokenizer,
max_new_tokens=1024,
truncation=True,
)

self.simplify_model = pipeline(
"text-generation",
model=model_obj,
tokenizer=tokenizer,
max_new_tokens=1024,
truncation=True,
)
else:
self.model = pipeline(
"text-generation",
model=self.model_name,
max_new_tokens=1024,
truncation=True,
torch_dtype=dtype, # Use the dynamic dtype
device=device, # Use the dynamic device
)

self.simplify_model = self.model

self.retriever: Optional[FAISS] = None

_lazy_load_deps()
self.prompt = ChatPromptTemplate.from_template(prompts.PROMPT_TEMPLATE)
self.child_prompt = ChatPromptTemplate.from_template(prompts.CHILD_FRIENDLY_PROMPT)
self.debug_prompt = ChatPromptTemplate.from_template(prompts.CODE_DEBUG_PROMPT)
self.context_prompt = ChatPromptTemplate.from_template(prompts.CODE_CONTEXT_PROMPT)
self.kids_debug_prompt = ChatPromptTemplate.from_template(prompts.KIDS_DEBUG_PROMPT)
self.kids_context_prompt = ChatPromptTemplate.from_template(prompts.KIDS_CONTEXT_PROMPT)

def ensure_model_loaded(self):
    """Ensure the underlying models are loaded (skipped in DEV_MODE)"""
    # DEV_MODE runs with placeholder models: nothing to load.
    if settings.DEV_MODE:
        logger.info("DEV_MODE active: skipping actual model loading in ensure_model_loaded")
        return
    # Already loaded — nothing to do.
    if self.model is not None:
        return
    logger.info(f"Lazily loading model for RAGAgent: {self.model_name}")
    pipe, _tokenizer = ModelManager.get_model(self.model_name, quantize=self.quantize)
    self.model = pipe
    # The simplify model shares the same pipeline instance.
    self.simplify_model = pipe


def set_model(self, model: str) -> None:
"""Update the model used by the agent"""
self.model_name = model
Expand Down
16 changes: 11 additions & 5 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from typing import Dict, List, Any, Optional
from typing import Dict, List, Any, Optional, Union

class Settings(BaseSettings):
"""Application settings loaded from environment variables"""

# Dev mode (THIS MUST EXIST)
DEV_MODE: bool = os.getenv("DEV_MODE", "0") == "1"
DEV_MODEL_NAME: str | None = None
PROD_MODEL_NAME: str | None = None
DEFAULT_MODEL: str | None = None
DEV_MODEL_NAME: Optional[str] = None
PROD_MODEL_NAME: Optional[str] = None
DEFAULT_MODEL: Optional[str] = None
AVAILABLE_MODELS: Optional[str] = None

API_KEYS: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
MODEL_CHANGE_PASSWORD: str = ""
Expand All @@ -37,4 +38,9 @@ class Config:
env_file = ".env"
extra = "allow" # this allows extra attribute if we have any

settings = Settings()
settings = Settings()
# Parse AVAILABLE_MODELS into a list
# The env var arrives as a comma-separated string; normalize it here to a
# list of stripped, non-empty names so consumers (e.g. /list-models and
# model-change validation) can use it directly.
# NOTE(review): this rebinds a field declared Optional[str] to a list —
# relies on pydantic not re-validating assignments; confirm
# validate_assignment stays disabled for Settings.
if settings.AVAILABLE_MODELS:
    settings.AVAILABLE_MODELS = [m.strip() for m in settings.AVAILABLE_MODELS.split(",") if m.strip()]
else:
    settings.AVAILABLE_MODELS = []
50 changes: 29 additions & 21 deletions app/routes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class PromptedLLMRequest(BaseModel):
top_k: int = Field(50, description="Top-k sampling parameter")

router = APIRouter(tags=["api"])
@router.get("/list-models")
def list_models():
    """Expose the deployment's configured model names to clients."""
    available = settings.AVAILABLE_MODELS
    return {"models": available}

# setup logging
logger = logging.getLogger("sugar-ai")
Expand Down Expand Up @@ -92,13 +98,10 @@ async def ask_question(
logger.info(f"REQUEST - /ask - User: {user_info['name']} - IP: {client_ip} - Question: {question[:50]}...")

try:
agent.ensure_model_loaded()
answer = agent.run(question)

# log completion
process_time = time.time() - start_time
logger.info(f"RESPONSE - User: {user_info['name']} - Success - Time: {process_time:.2f}s")

# check quota
api_key = next(
key for key, value in settings.API_KEYS.items()
if value['name'] == user_info['name']
Expand All @@ -107,9 +110,8 @@ async def ask_question(
settings.MAX_DAILY_REQUESTS
- user_quotas.get(api_key, {}).get("count", 0)
)

return {
"answer": answer,
"answer": answer,
"user": user_info["name"],
"quota": {"remaining": remaining, "total": settings.MAX_DAILY_REQUESTS}
}
Expand All @@ -130,18 +132,15 @@ async def ask_llm(
logger.info(f"REQUEST - /ask-llm - User: {user_info['name']} - IP: {client_ip} - Question: {question[:50]}...")

try:
agent.ensure_model_loaded()
response = agent.model(question)
answer = extract_answer_from_output(response)

process_time = time.time() - start_time
logger.info(f"RESPONSE - User: {user_info['name']} - Success - Time: {process_time:.2f}s")

# check quota
api_key = next(key for key, value in settings.API_KEYS.items() if value['name'] == user_info['name'])
remaining = settings.MAX_DAILY_REQUESTS - user_quotas.get(api_key, {}).get("count", 0)

return {
"answer": answer,
"answer": answer,
"user": user_info["name"],
"quota": {"remaining": remaining, "total": settings.MAX_DAILY_REQUESTS}
}
Expand Down Expand Up @@ -260,33 +259,34 @@ async def ask_llm_prompted(
except Exception as e:
logger.error(f"ERROR - User: {user_info['name']} - Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

@router.post("/debug")
async def debug(
code: str,
context: bool,
context: bool = False,
user_info: dict = Depends(verify_api_key),
request: Request = None
):
"""Process python code for debugging"""
"""Process python code for debugging using lightweight sandboxed execution"""
start_time = time.time()

client_ip = request.client.host if request else "unknown"
logger.info(f"REQUEST - /debug - User: {user_info['name']} - IP: {client_ip} - code: {code[:50]}...")

try:
response = agent.debug(code, context)
answer = response
from app.routes.debug import debug_code
result = debug_code(code)

process_time = time.time() - start_time
logger.info(f"RESPONSE - User: {user_info['name']} - Success - Time: {process_time:.2f}s")
logger.info(f"RESPONSE - /debug - User: {user_info['name']} - Success - Time: {process_time:.2f}s")

# check quota
api_key = next(key for key, value in settings.API_KEYS.items() if value['name'] == user_info['name'])
remaining = settings.MAX_DAILY_REQUESTS - user_quotas.get(api_key, {}).get("count", 0)

return {
"answer": answer,
"answer": result["answer"],
"status": result.get("status"),
"user": user_info["name"],
"quota": {"remaining": remaining, "total": settings.MAX_DAILY_REQUESTS}
}
Expand Down Expand Up @@ -319,10 +319,18 @@ async def change_model(
logger.warning(f"Invalid password for model change by: {user_info['name']} from {client_ip}")
raise HTTPException(status_code=403, detail="Invalid model change password")

# Validate model name
if model not in settings.AVAILABLE_MODELS:
logger.warning(f"Invalid model name '{model}' attempted by {user_info['name']} from {client_ip}")
return {"success": False, "error": "Invalid model name", "available_models": settings.AVAILABLE_MODELS}

try:
agent.set_model(model)
from app.ai import ModelManager
model_obj, tokenizer = ModelManager.get_model(model)
agent.model = model_obj
agent.tokenizer = tokenizer
logger.info(f"Model changed to {model} by {user_info['name']}")
return {"message": f"Model changed to {model}", "user": user_info["name"]}
return {"success": True, "message": f"Model changed to {model}", "user": user_info["name"]}
except Exception as e:
logger.error(f"Error changing model to {model} by {user_info['name']}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error changing model: {str(e)}")
return {"success": False, "error": f"Error changing model: {str(e)}"}
Loading