rag: adding ability to query NV-Ingest vectordb via argument
* init

* skeleton

* add job id pass to pdf service

* add vdb

* fix ci

* adding job id to query rag endpoint

* stupid comment remove

* uptake jeremy work

* update test to check for vdb

* working rag

* addressing comments

* ruff
ishandhanani authored Nov 19, 2024
1 parent 71cecc1 commit fe85e77
Showing 4 changed files with 113 additions and 12 deletions.
68 changes: 64 additions & 4 deletions services/APIService/main.py
@@ -10,7 +10,13 @@
    WebSocketDisconnect,
    Query,
)
from shared.api_types import ServiceType, JobStatus, StatusUpdate, TranscriptionParams
from shared.api_types import (
    ServiceType,
    JobStatus,
    StatusUpdate,
    TranscriptionParams,
    RAGRequest,
)
from shared.prompt_types import PromptTracker
from shared.podcast_types import SavedPodcast, SavedPodcastWithAudio, Conversation
from shared.connection import ConnectionManager
@@ -21,6 +27,7 @@
from pydantic import ValidationError
import redis
import requests
import httpx
import ujson as json
import uuid
import os
@@ -33,7 +40,13 @@
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(debug=True)
app = FastAPI(
    debug=True,
    title="AI Research Assistant API Service",
    description="API Service for the AI Research Assistant project",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Initialize OpenTelemetry
telemetry = OpenTelemetryInstrumentation()
@@ -62,6 +75,10 @@
# MP3 Cache TTL
MP3_CACHE_TTL = 60 * 60 * 4 # 4 hours

# NV-Ingest
DEFAULT_TIMEOUT = 600 # seconds
NV_INGEST_RETRIEVE_URL = "https://nv-ingest-rest-endpoint.brevlab.com/v1"

# CORS setup
CORS_ORIGINS = os.getenv(
"CORS_ORIGINS",
@@ -145,7 +162,9 @@ async def websocket_endpoint(websocket: WebSocket, job_id: str):


def process_pdf_task(
    job_id: str, files_content: List[bytes], transcription_params: TranscriptionParams
    job_id: str,
    files_content: List[bytes],
    transcription_params: TranscriptionParams,
):
    with telemetry.tracer.start_as_current_span("api.process_pdf_task") as span:
        span.set_attribute("job_id", job_id)
@@ -173,8 +192,13 @@ def process_pdf_task(
            for i, content in enumerate(files_content)
        ]

        logger.info(
            f"Sending {len(files)} PDFs to PDF Service for {job_id} with VDB task: {transcription_params.vdb_task}"
        )
        requests.post(
            f"{PDF_SERVICE_URL}/convert", files=files, data={"job_id": job_id}
            f"{PDF_SERVICE_URL}/convert",
            files=files,
            data={"job_id": job_id, "vdb_task": transcription_params.vdb_task},
        )

        # Monitor services
@@ -643,6 +667,42 @@ async def delete_saved_podcast(
        )


@app.post("/query_vector_db")
async def query_vector_db(
    payload: RAGRequest,
):
    """RAG endpoint that interfaces with NV-Ingest to retrieve top k results"""
    with telemetry.tracer.start_as_current_span("api.query_vector_db") as span:
        span.set_attribute("job_id", payload.job_id)
        span.set_attribute("k", payload.k)

        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            try:
                response = await client.post(
                    f"{NV_INGEST_RETRIEVE_URL}/query",
                    json={
                        "query": payload.query,
                        "k": payload.k,
                        "job_id": payload.job_id,
                    },
                )
                if response.status_code != 200:
                    span.set_status(
                        StatusCode.ERROR, "failed to retrieve from NV-Ingest"
                    )
                    raise HTTPException(
                        status_code=response.status_code,
                        detail=f"NV-Ingest error: {response.text}",
                    )
                return response.json()
            except HTTPException:
                # Re-raise the status set above instead of wrapping it in a 500
                raise
            except Exception as e:
                span.set_status(StatusCode.ERROR, "failed to retrieve from NV-Ingest")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to retrieve from NV-Ingest: {str(e)}",
                )


@app.get("/health")
async def health():
"""Health check endpoint with OpenTelemetry instrumentation"""
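With this change, retrieval can be exercised directly against the API Service. A minimal sketch of a client call (the base URL and job ID below are placeholders, not values from this commit):

import requests

response = requests.post(
    "http://localhost:8000/query_vector_db",
    json={"query": "What are the key findings?", "k": 3, "job_id": "<job-id>"},
)
response.raise_for_status()  # non-200 responses carry the NV-Ingest error detail
print(response.json())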
17 changes: 12 additions & 5 deletions services/PDFService/main.py
@@ -35,7 +35,9 @@
DEFAULT_TIMEOUT = 600 # seconds


async def convert_pdfs_to_markdown(pdf_paths: List[str]) -> List[PDFConversionResult]:
async def convert_pdfs_to_markdown(
    pdf_paths: List[str], job_id: str, vdb_task: bool = False
) -> List[PDFConversionResult]:
"""Convert multiple PDFs to Markdown using the external API service"""
logger.info(f"Sending {len(pdf_paths)} PDFs to external conversion service")
with telemetry.tracer.start_as_current_span("pdf.convert_pdfs_to_markdown") as span:
@@ -58,7 +60,9 @@ async def convert_pdfs_to_markdown(pdf_paths: List[str]) -> List[PDFConversionRe
            span.set_attribute("model_api_url", MODEL_API_URL)
            logger.info(f"Sending {len(files)} files to model API")
            response = await client.post(
                f"{MODEL_API_URL}/convert", files=files
                f"{MODEL_API_URL}/convert",
                files=files,
                data={"job_id": job_id, "vdb_task": vdb_task},
            )
        finally:
            # Clean up file handles after request is complete
@@ -149,7 +153,9 @@ async def convert_pdfs_to_markdown(pdf_paths: List[str]) -> List[PDFConversionRe
        )


async def process_pdfs(job_id: str, contents: List[bytes], filenames: List[str]):
async def process_pdfs(
    job_id: str, contents: List[bytes], filenames: List[str], vdb_task: bool = False
):
"""Process multiple PDFs and return metadata for each"""
with telemetry.tracer.start_as_current_span("pdf.process_pdfs") as span:
try:
@@ -183,7 +189,7 @@ async def process_pdfs(job_id: str, contents: List[bytes], filenames: List[str])
f"Starting PDF to Markdown conversion for {len(temp_files)} files"
)
# Convert all PDFs in a single batch
results = await convert_pdfs_to_markdown(temp_files)
results = await convert_pdfs_to_markdown(temp_files, job_id, vdb_task)
logger.info(f"Conversion completed, processing {len(results)} results")

# Create metadata list
@@ -252,6 +258,7 @@ async def convert_pdf(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    job_id: str = Form(...),
    vdb_task: bool = Form(False),
):
"""Convert multiple PDFs to Markdown"""
with telemetry.tracer.start_as_current_span("pdf.convert_pdf") as span:
@@ -274,7 +281,7 @@ async def convert_pdf(
        job_manager.create_job(job_id)

        # Start processing in background
        background_tasks.add_task(process_pdfs, job_id, contents, filenames)
        background_tasks.add_task(process_pdfs, job_id, contents, filenames, vdb_task)

        return {"job_id": job_id}

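For reference, a sketch of the multipart request the API Service now sends to this /convert endpoint; the URL and file name are assumptions for illustration:

import requests

with open("paper.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8001/convert",
        files=[("files", ("paper.pdf", f, "application/pdf"))],
        data={"job_id": "my-job-id", "vdb_task": True},  # parsed by the Form(False) parameter
    )
print(resp.json())  # {"job_id": "my-job-id"}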
10 changes: 10 additions & 0 deletions shared/shared/api_types.py
@@ -57,6 +57,10 @@ class TranscriptionParams(BaseModel):
    guide: Optional[str] = Field(
        None, description="Optional guidance for the transcription focus and structure"
    )
    vdb_task: bool = Field(
        False,
        description="If True, creates a VDB task when running NV-Ingest, enabling retrieval over the ingested documents",
    )

    @model_validator(mode="after")
    def validate_monologue_settings(self) -> "TranscriptionParams":
@@ -93,3 +97,9 @@ def validate_monologue_settings(self) -> "TranscriptionParams":
class TranscriptionRequest(TranscriptionParams):
    pdf_metadata: List[PDFMetadata]
    job_id: str


class RAGRequest(BaseModel):
    query: str = Field(..., description="The search query to process")
    k: int = Field(..., description="Number of results to retrieve", ge=1)
    job_id: str = Field(..., description="The unique job identifier")
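A quick sketch of the new model's validation, assuming the shared package is importable; ge=1 rejects non-positive k before a request ever reaches NV-Ingest:

from pydantic import ValidationError
from shared.api_types import RAGRequest

req = RAGRequest(query="What is the main topic?", k=3, job_id="my-job-id")
print(req.model_dump())  # {'query': 'What is the main topic?', 'k': 3, 'job_id': 'my-job-id'}

try:
    RAGRequest(query="What is the main topic?", k=0, job_id="my-job-id")
except ValidationError as e:
    print(e)  # fails: k must be >= 1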
30 changes: 27 additions & 3 deletions tests/test.py
@@ -217,7 +217,9 @@ def test_saved_podcasts(base_url: str, job_id: str, max_retries=5, retry_delay=5
    print(f"Successfully retrieved audio data, size: {len(audio_data)} bytes")


def test_api(base_url: str, pdf_files: List[str], monologue: bool = False):
def test_api(
    base_url: str, pdf_files: List[str], monologue: bool = False, vdb: bool = False
):
    voice_mapping = {
        "speaker-1": "iP95p4xoKVk53GoZ742B",
    }
Expand Down Expand Up @@ -250,7 +252,8 @@ def test_api(base_url: str, pdf_files: List[str], monologue: bool = False):
"voice_mapping": voice_mapping,
"guide": None,
"monologue": monologue,
"userId": TEST_USER_ID, # Add userId to transcription params
"userId": TEST_USER_ID,
"vdb_task": vdb,
}

if not monologue:
@@ -316,6 +319,21 @@ def test_api(base_url: str, pdf_files: List[str], monologue: bool = False):
        # Test saved podcasts endpoints with the newly created job_id
        test_saved_podcasts(base_url, job_id)

        # Test RAG endpoint if vdb flag is enabled
        if vdb:
            print("\nTesting RAG endpoint...")
            test_query = "What is the main topic of this document?"
            rag_response = requests.post(
                f"{base_url}/query_vector_db",
                json={"query": test_query, "k": 3, "job_id": job_id},
            )
            assert (
                rag_response.status_code == 200
            ), f"RAG endpoint failed: {rag_response.text}"
            rag_results = rag_response.json()
            print(f"RAG Query: '{test_query}'")
            print(f"RAG Results: {json.dumps(rag_results, indent=2)}")

    finally:
        monitor.stop()

@@ -335,11 +353,17 @@ def test_api(base_url: str, pdf_files: List[str], monologue: bool = False):
        action="store_true",
        help="Generate a monologue instead of a dialogue",
    )
    parser.add_argument(
        "--vdb",
        action="store_true",
        help="Enable Vector Database processing",
    )

    args = parser.parse_args()
    print(f"API URL: {args.api_url}")
    print(f"Processing PDF files: {args.pdf_files}")
    print(f"Monologue mode: {args.monologue}")
    print(f"VDB mode: {args.vdb}")
    print(f"Using test user ID: {TEST_USER_ID}")

    test_api(args.api_url, args.pdf_files, args.monologue)
    test_api(args.api_url, args.pdf_files, args.monologue, args.vdb)
