From 5df87adcc98c62b6f1a62d37b1001d4940d07a7d Mon Sep 17 00:00:00 2001
From: Andrew Wang
Date: Tue, 19 Nov 2024 17:23:46 -0800
Subject: [PATCH] update test.py

---
 services/APIService/main.py       |  41 +++++++-----
 services/AgentService/test_api.py |   5 +-
 services/PDFService/main.py       |   7 +-
 shared/shared/pdf_types.py        |  13 +---
 tests/prod-test.sh                |   2 +-
 tests/test.py                     | 107 +++++++++++++++++++++---------
 6 files changed, 116 insertions(+), 59 deletions(-)

diff --git a/services/APIService/main.py b/services/APIService/main.py
index 57faa94..6d445b9 100644
--- a/services/APIService/main.py
+++ b/services/APIService/main.py
@@ -19,7 +19,7 @@
 )
 from shared.prompt_types import PromptTracker
 from shared.podcast_types import SavedPodcast, SavedPodcastWithAudio, Conversation
-from shared.pdf_types import PDFFileUpload, FileContentTuple
+from shared.pdf_types import FileContentTuple
 from shared.connection import ConnectionManager
 from shared.storage import StorageManager
 from shared.otel import OpenTelemetryInstrumentation, OpenTelemetryConfig
@@ -35,7 +35,7 @@
 import logging
 import time
 import asyncio
-from typing import Dict, List, Union, Annotated
+from typing import Dict, List
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -98,6 +98,7 @@
 )
 logger.info(f"CORS configured with allowed origins: {allowed_origins}")
 
+
 @app.websocket("/ws/status/{job_id}")
 async def websocket_endpoint(websocket: WebSocket, job_id: str):
     try:
@@ -182,9 +183,12 @@ def process_pdf_task(
                 "application/pdf",
                 transcription_params,
             )
-            logger.info(
-                f"Stored {len(files)} original PDFs for {job_id} in storage"
-            )
+            logger.info(f"Stored {len(files)} original PDFs for {job_id} in storage")
+            files_for_request = []
+            for idx, (content, file_type) in enumerate(files):
+                files_for_request.append(
+                    ("files", (f"file_{idx}.pdf", content, "application/pdf"))
+                )
             logger.info(
                 f"Sending {len(files)} PDFs to PDF Service for {job_id} with VDB task: {transcription_params.vdb_task}"
             )
@@ -289,30 +293,35 @@ def process_pdf_task(
 @app.post("/process_pdf", status_code=202)
 async def process_pdf(
     background_tasks: BackgroundTasks,
-    files: Annotated[Union[PDFFileUpload, List[PDFFileUpload]], File(...)],
+    files: List[UploadFile] = File(...),
+    file_types: List[str] = Form(...),
     transcription_params: str = Form(...),
 ):
     with telemetry.tracer.start_as_current_span("api.process_pdf") as span:
-        # Convert single file to list for consistent handling
-        files = [files] if isinstance(files, PDFFileUpload) else files
+        if len(files) != len(file_types):
+            raise HTTPException(
+                status_code=400,
+                detail="Number of files must match number of file types",
+            )
 
         span.set_attribute("request", transcription_params)
         span.set_attribute("num_files", len(files))
-        if len(files) == 1 and files[0].type != "target":
+
+        if len(files) == 1 and file_types[0] != "target":
             raise HTTPException(
-                status_code=400,
-                detail="Single file must be designated as 'target'"
+                status_code=400, detail="Single file must be designated as 'target'"
             )
 
         # Ensure at least one target file
-        if not any(f.type == "target" for f in files):
+        if not any(ft == "target" for ft in file_types):
             raise HTTPException(
                 status_code=400,
-                detail="At least one file must be designated as 'target'"
+                detail="At least one file must be designated as 'target'",
             )
 
+        # Validate all files are PDFs
         for file in files:
-            if file.file.content_type != "application/pdf":
+            if file.content_type != "application/pdf":
                 span.set_status(
                     status=StatusCode.ERROR, description="invalid file type"
                 )
@@ -334,8 +343,8 @@ async def process_pdf(
 
         # Read all files
         files_data: List[FileContentTuple] = []
-        for file_upload, file_type in files:
-            content = await file_upload.file.read()
+        for file, file_type in zip(files, file_types):
+            content = await file.read()
             files_data.append((content, file_type))
 
         # Start processing
diff --git a/services/AgentService/test_api.py b/services/AgentService/test_api.py
index 2d66a4b..ba2ce0b 100644
--- a/services/AgentService/test_api.py
+++ b/services/AgentService/test_api.py
@@ -13,7 +13,10 @@ def test_transcribe_api():
 
     # Create a proper TranscriptionRequest
     pdf_metadata_1 = PDFMetadata(
-        filename="sample.pdf", markdown="Sample markdown content", summary="", type="target"
+        filename="sample.pdf",
+        markdown="Sample markdown content",
+        summary="",
+        type="target",
     )
 
     pdf_metadata_2 = PDFMetadata(
diff --git a/services/PDFService/main.py b/services/PDFService/main.py
index 7884ebc..be070d0 100644
--- a/services/PDFService/main.py
+++ b/services/PDFService/main.py
@@ -9,7 +9,12 @@
 import asyncio
 import ujson as json
 from typing import List
-from shared.pdf_types import PDFConversionResult, ConversionStatus, PDFMetadata, FileContentTuple
+from shared.pdf_types import (
+    PDFConversionResult,
+    ConversionStatus,
+    PDFMetadata,
+    FileContentTuple,
+)
 from shared.api_types import ServiceType, JobStatus, StatusResponse
 
 logging.basicConfig(level=logging.INFO)
diff --git a/shared/shared/pdf_types.py b/shared/shared/pdf_types.py
index d314e48..facb15d 100644
--- a/shared/shared/pdf_types.py
+++ b/shared/shared/pdf_types.py
@@ -1,9 +1,10 @@
-from fastapi import UploadFile, Form, File
+from fastapi import UploadFile
 from pydantic import BaseModel, Field
 from typing import Optional, Union, Literal, Tuple
 from datetime import datetime
 from enum import Enum
 
+
 class ConversionStatus(str, Enum):
     SUCCESS = "success"
     FAILED = "failed"
@@ -25,14 +26,6 @@ class PDFMetadata(BaseModel):
     error: Optional[str] = None
     created_at: datetime = Field(default_factory=datetime.utcnow)
 
-class PDFFileUpload:
-    def __init__(
-        self,
-        file: UploadFile = File(...),
-        type: Literal["target", "context"] = Form(...)
-    ):
-        self.file = file
-        self.type = type
 
 FileTypeTuple = Tuple[UploadFile, Literal["target", "context"]]
-FileContentTuple = Tuple[bytes, Literal["target", "context"]]
\ No newline at end of file
+FileContentTuple = Tuple[bytes, Literal["target", "context"]]
diff --git a/tests/prod-test.sh b/tests/prod-test.sh
index 35bb4af..a8aa4f5 100755
--- a/tests/prod-test.sh
+++ b/tests/prod-test.sh
@@ -6,6 +6,6 @@ python3 test.py --monologue \
     db-context.pdf \
    gs-context.pdf \
    hsbc-context.pdf \
-    investorpres-main.pdf \
+    investorpres-main.pdf target \
    jpm-context.pdf \
    keybanc-context.pdf
\ No newline at end of file
diff --git a/tests/test.py b/tests/test.py
index 064564f..974a152 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -8,7 +8,7 @@
 import asyncio
 from urllib.parse import urljoin
 import argparse
-from typing import List
+from typing import List, Tuple
 
 # Add global TEST_USER_ID
 TEST_USER_ID = "test-userid"
@@ -218,7 +218,10 @@ def test_saved_podcasts(base_url: str, job_id: str, max_retries=5, retry_delay=5
 
 
 def test_api(
-    base_url: str, pdf_files: List[str], monologue: bool = False, vdb: bool = False
+    base_url: str,
+    pdf_files_with_types: List[Tuple[str, str]],
+    monologue: bool = False,
+    vdb: bool = False,
 ):
     voice_mapping = {
         "speaker-1": "iP95p4xoKVk53GoZ742B",
@@ -230,19 +233,18 @@ def test_api(
     process_url = f"{base_url}/process_pdf"
 
     # Update path resolution
-    current_dir = os.path.dirname(
-        os.path.abspath(__file__)
-    )  # This gets /tests directory
-    project_root = os.path.dirname(current_dir)  # Go up one level to project root
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    project_root = os.path.dirname(current_dir)
     samples_dir = os.path.join(project_root, "samples")
 
-    # Rest of the path handling remains the same
-    sample_pdf_paths = []
-    for pdf_file in pdf_files:
+    sample_pdf_paths_with_types = []
+    for pdf_file, file_type in pdf_files_with_types:
         if os.path.isabs(pdf_file):
-            sample_pdf_paths.append(pdf_file)
+            sample_pdf_paths_with_types.append((pdf_file, file_type))
         else:
-            sample_pdf_paths.append(os.path.join(samples_dir, pdf_file))
+            sample_pdf_paths_with_types.append(
+                (os.path.join(samples_dir, pdf_file), file_type)
+            )
 
     # Prepare the payload with updated schema and userId
     transcription_params = {
@@ -265,30 +267,40 @@ def test_api(
     )
     print(f"Using voices: {voice_mapping}")
 
-    pdf_files = [open(path, "rb") for path in sample_pdf_paths]
+    # Prepare multipart form data
+    form_data = []
+    file_types = []
+
+    # Add each file to the form data
+    for path, file_type in sample_pdf_paths_with_types:
+        with open(path, "rb") as pdf_file:
+            content = pdf_file.read()
+        form_data.append(
+            ("files", (os.path.basename(path), content, "application/pdf"))
+        )
+        file_types.append(file_type)
+
+    # Add the file types as separate form fields
+    for file_type in file_types:
+        form_data.append(("file_types", (None, file_type)))
+
+    # Add transcription parameters
+    form_data.append(("transcription_params", (None, json.dumps(transcription_params))))
+
     try:
-        files = [
-            ("files", (os.path.basename(path), pdf_file, "application/pdf"))
-            for path, pdf_file in zip(sample_pdf_paths, pdf_files)
-        ]
-
-        response = requests.post(
-            process_url,
-            files=files,
-            data={"transcription_params": json.dumps(transcription_params)},
-        )
+        response = requests.post(process_url, files=form_data)
 
         assert (
             response.status_code == 202
-        ), f"Expected status code 202, but got {response.status_code}"
+        ), f"Expected status code 202, but got {response.status_code}. Response: {response.text}"
 
         job_data = response.json()
         assert "job_id" in job_data, "Response missing job_id"
         job_id = job_data["job_id"]
         print(f"[{datetime.now().strftime('%H:%M:%S')}] Job ID received: {job_id}")
-    finally:
-        for f in pdf_files:
-            f.close()
+    except Exception as e:
+        print(f"Error during PDF submission: {e}")
+        raise
 
     # Step 2: Start monitoring status via WebSocket
     monitor = StatusMonitor(base_url, job_id)
@@ -340,9 +352,29 @@ def test_api(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Process PDF files for audio conversion"
+        description="Process PDF files for audio conversion",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+    Examples:
+    # Process single file (defaults to context)
+    python test.py file1.pdf
+
+    # Process single file as target
+    python test.py file1.pdf target
+
+    # Process multiple files with explicit types
+    python test.py file1.pdf target file2.pdf context file3.pdf context
+
+    # Process multiple files (defaulting to context)
+    python test.py file1.pdf target file2.pdf file3.pdf
+    """,
+    )
+
+    parser.add_argument(
+        "files",
+        nargs="+",
+        help="PDF files and their types (optional). Format: <file> [type] <file> [type] ...",
     )
-    parser.add_argument("pdf_files", nargs="+", help="PDF files to process")
     parser.add_argument(
         "--api-url",
         default=os.getenv("API_SERVICE_URL", "http://localhost:8002"),
@@ -360,10 +392,25 @@ def test_api(
     )
     args = parser.parse_args()
+
+    # Process the files argument to pair files with their types
+    pdf_files_with_types = []
+    i = 0
+    while i < len(args.files):
+        pdf_file = args.files[i]
+        # Check if next argument is a type specification
+        if i + 1 < len(args.files) and args.files[i + 1] in ["target", "context"]:
+            file_type = args.files[i + 1]
+            i += 2
+        else:
+            file_type = "context"  # default type
+            i += 1
+        pdf_files_with_types.append((pdf_file, file_type))
+
 
     print(f"API URL: {args.api_url}")
-    print(f"Processing PDF files: {args.pdf_files}")
+    print(f"Processing PDF files: {pdf_files_with_types}")
     print(f"Monologue mode: {args.monologue}")
     print(f"VDB mode: {args.vdb}")
     print(f"Using test user ID: {TEST_USER_ID}")
 
-    test_api(args.api_url, args.pdf_files, args.monologue, args.vdb)
+    test_api(args.api_url, pdf_files_with_types, args.monologue, args.vdb)
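
For reference, a client-side sketch of the multipart contract this patch gives /process_pdf: one "files" part per PDF plus a parallel "file_types" part ("target" or "context") for each, alongside the existing JSON-encoded "transcription_params" field. It mirrors what tests/test.py sends after the change; the base URL reuses the test default (http://localhost:8002), the file names are hypothetical, and the transcription_params payload below is only a placeholder, not the full schema the services expect.

import json

import requests

# Hypothetical local PDFs; at least one must be typed "target".
pdf_specs = [("investorpres-main.pdf", "target"), ("db-context.pdf", "context")]

form_data = []
for path, _ in pdf_specs:
    with open(path, "rb") as fh:
        # One "files" part per PDF, in the same order as the "file_types" parts below.
        form_data.append(("files", (path, fh.read(), "application/pdf")))
for _, file_type in pdf_specs:
    form_data.append(("file_types", (None, file_type)))

# Placeholder payload; the real schema lives in tests/test.py and the shared types.
form_data.append(("transcription_params", (None, json.dumps({"userId": "test-userid"}))))

response = requests.post("http://localhost:8002/process_pdf", files=form_data)
assert response.status_code == 202, response.text
print(response.json()["job_id"])  # monitor progress via /ws/status/{job_id}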