Skip to content

Commit

Permalink
cleanup: refactor shared types, deduplicate duplicated types (#41)
Browse files Browse the repository at this point in the history
* refactor shared types, deduplicate duplicated types

* minor path fixes

---------

Co-authored-by: Andrew Wang <[email protected]>
  • Loading branch information
aw632 and Andrew Wang authored Nov 19, 2024
1 parent 1009456 commit 71cecc1
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 129 deletions.
13 changes: 3 additions & 10 deletions services/APIService/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,9 @@
WebSocketDisconnect,
Query,
)
from shared.shared_types import (
ServiceType,
JobStatus,
StatusUpdate,
TranscriptionParams,
SavedPodcast,
SavedPodcastWithAudio,
Conversation,
PromptTracker,
)
from shared.api_types import ServiceType, JobStatus, StatusUpdate, TranscriptionParams
from shared.prompt_types import PromptTracker
from shared.podcast_types import SavedPodcast, SavedPodcastWithAudio, Conversation
from shared.connection import ConnectionManager
from shared.storage import StorageManager
from shared.otel import OpenTelemetryInstrumentation, OpenTelemetryConfig
Expand Down
7 changes: 3 additions & 4 deletions services/AgentService/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from fastapi import FastAPI, BackgroundTasks, HTTPException
from shared.shared_types import (
from shared.api_types import (
ServiceType,
JobStatus,
Conversation,
TranscriptionRequest,
PodcastOutline,
)
from shared.podcast_types import Conversation, PodcastOutline
from shared.api_types import TranscriptionRequest
from podcast_flow import (
podcast_summarize_pdfs,
podcast_generate_raw_outline,
Expand Down
9 changes: 3 additions & 6 deletions services/AgentService/monologue_flow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from shared.shared_types import (
JobStatus,
Conversation,
PDFMetadata,
TranscriptionRequest,
)
from shared.api_types import JobStatus, TranscriptionRequest
from shared.podcast_types import Conversation
from shared.pdf_types import PDFMetadata
from shared.llmmanager import LLMManager
from shared.job import JobStatusManager
from typing import List, Dict
Expand Down
10 changes: 3 additions & 7 deletions services/AgentService/podcast_flow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from shared.shared_types import (
JobStatus,
Conversation,
PDFMetadata,
TranscriptionRequest,
PodcastOutline,
)
from shared.pdf_types import PDFMetadata
from shared.podcast_types import Conversation, PodcastOutline
from shared.api_types import JobStatus, TranscriptionRequest
from shared.llmmanager import LLMManager
from shared.job import JobStatusManager
from typing import List, Dict, Any, Coroutine
Expand Down
3 changes: 2 additions & 1 deletion services/AgentService/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import ujson as json
import os
import time
from shared.shared_types import TranscriptionRequest, PDFMetadata
from shared.api_types import TranscriptionRequest
from shared.pdf_types import PDFMetadata


def test_transcribe_api():
Expand Down
31 changes: 5 additions & 26 deletions services/PDFService/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException, Form
from shared.shared_types import ServiceType, JobStatus, StatusResponse
from shared.job import JobStatusManager
from shared.otel import OpenTelemetryInstrumentation, OpenTelemetryConfig
from opentelemetry.trace.status import StatusCode
Expand All @@ -10,9 +9,8 @@
import asyncio
import ujson as json
from typing import List
from pydantic import BaseModel, Field
from enum import Enum
from datetime import datetime
from shared.pdf_types import PDFConversionResult, ConversionStatus, PDFMetadata
from shared.api_types import ServiceType, JobStatus, StatusResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -31,31 +29,12 @@
job_manager = JobStatusManager(ServiceType.PDF, telemetry=telemetry)

# Configuration
MODEL_API_URL = os.getenv("MODEL_API_URL", "https://pdf-gyrdps568.brevlab.com")
MODEL_API_URL = os.getenv(
"MODEL_API_URL", "https://nv-ingest-rest-endpoint.brevlab.com/v1"
)
DEFAULT_TIMEOUT = 600 # seconds


class ConversionStatus(str, Enum):
SUCCESS = "success"
FAILED = "failed"


class PDFConversionResult(BaseModel):
filename: str
content: str = ""
status: ConversionStatus
error: str | None = None


class PDFMetadata(BaseModel):
filename: str
markdown: str = ""
summary: str = ""
status: ConversionStatus
error: str | None = None
created_at: datetime = Field(default_factory=datetime.utcnow)


async def convert_pdfs_to_markdown(pdf_paths: List[str]) -> List[PDFConversionResult]:
"""Convert multiple PDFs to Markdown using the external API service"""
logger.info(f"Sending {len(pdf_paths)} PDFs to external conversion service")
Expand Down
2 changes: 1 addition & 1 deletion services/PDFService/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
from typing import Optional, List
from shared.shared_types import StatusResponse
from shared.api_types import StatusResponse
import sys
from pathlib import Path

Expand Down
2 changes: 1 addition & 1 deletion services/TTSService/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import FastAPI, BackgroundTasks, HTTPException
from shared.shared_types import ServiceType, JobStatus
from shared.api_types import ServiceType, JobStatus
from shared.job import JobStatusManager
from fastapi.responses import Response
from pydantic import BaseModel
Expand Down
78 changes: 8 additions & 70 deletions shared/shared/shared_types.py → shared/shared/api_types.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from pydantic import BaseModel, Field, model_validator
from typing import Optional, Dict, List, Literal
from typing import Optional, Dict, List
from .pdf_types import PDFMetadata
from enum import Enum


class ServiceType(str, Enum):
PDF = "pdf"
AGENT = "agent"
TTS = "tts"


class JobStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"


class ServiceType(str, Enum):
PDF = "pdf"
AGENT = "agent"
TTS = "tts"


class StatusUpdate(BaseModel):
job_id: str
status: JobStatus
Expand All @@ -32,48 +33,6 @@ class StatusResponse(BaseModel):
message: Optional[str] = None


class SavedPodcast(BaseModel):
job_id: str
filename: str
created_at: str
size: int
transcription_params: Optional[Dict] = {}


class SavedPodcastWithAudio(SavedPodcast):
audio_data: str


# Transcript schema
class DialogueEntry(BaseModel):
text: str
speaker: Literal["speaker-1", "speaker-2"]


class Conversation(BaseModel):
scratchpad: str
dialogue: List[DialogueEntry]


# Prompt tracker schema
class ProcessingStep(BaseModel):
step_name: str
prompt: str
response: str
model: str
timestamp: float


class PromptTracker(BaseModel):
steps: List[ProcessingStep]


class PDFMetadata(BaseModel):
filename: str
markdown: str
summary: str = ""


class TranscriptionParams(BaseModel):
userId: str = Field(..., description="KAS User ID")
name: str = Field(..., description="Name of the podcast")
Expand Down Expand Up @@ -134,24 +93,3 @@ def validate_monologue_settings(self) -> "TranscriptionParams":
class TranscriptionRequest(TranscriptionParams):
pdf_metadata: List[PDFMetadata]
job_id: str


class SegmentPoint(BaseModel):
description: str


class SegmentTopic(BaseModel):
title: str
points: List[SegmentPoint]


class PodcastSegment(BaseModel):
section: str
topics: List[SegmentTopic]
duration: int
references: List[str]


class PodcastOutline(BaseModel):
title: str
segments: List[PodcastSegment]
2 changes: 1 addition & 1 deletion shared/shared/job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from shared.shared_types import ServiceType
from shared.api_types import ServiceType
from shared.otel import OpenTelemetryInstrumentation
import redis
import time
Expand Down
25 changes: 25 additions & 0 deletions shared/shared/pdf_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
from enum import Enum


class ConversionStatus(str, Enum):
SUCCESS = "success"
FAILED = "failed"


class PDFConversionResult(BaseModel):
filename: str
content: str = ""
status: ConversionStatus
error: Optional[str] = None


class PDFMetadata(BaseModel):
filename: str
markdown: str = ""
summary: str = ""
status: ConversionStatus
error: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.utcnow)
45 changes: 45 additions & 0 deletions shared/shared/podcast_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pydantic import BaseModel
from typing import Optional, Dict, Literal, List


class SavedPodcast(BaseModel):
job_id: str
filename: str
created_at: str
size: int
transcription_params: Optional[Dict] = {}


class SavedPodcastWithAudio(SavedPodcast):
audio_data: str


class DialogueEntry(BaseModel):
text: str
speaker: Literal["speaker-1", "speaker-2"]


class Conversation(BaseModel):
scratchpad: str
dialogue: List[DialogueEntry]


class SegmentPoint(BaseModel):
description: str


class SegmentTopic(BaseModel):
title: str
points: List[SegmentPoint]


class PodcastSegment(BaseModel):
section: str
topics: List[SegmentTopic]
duration: int
references: List[str]


class PodcastOutline(BaseModel):
title: str
segments: List[PodcastSegment]
2 changes: 1 addition & 1 deletion shared/shared/prompt_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import logging
from .storage import StorageManager
from .shared_types import ProcessingStep, PromptTracker as PromptTrackerModel
from .prompt_types import ProcessingStep, PromptTracker as PromptTrackerModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down
14 changes: 14 additions & 0 deletions shared/shared/prompt_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel
from typing import List


class ProcessingStep(BaseModel):
step_name: str
prompt: str
response: str
model: str
timestamp: float


class PromptTracker(BaseModel):
steps: List[ProcessingStep]
2 changes: 1 addition & 1 deletion shared/shared/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import base64
from minio import Minio
from minio.error import S3Error
from shared.shared_types import TranscriptionParams
from shared.api_types import TranscriptionParams
from shared.otel import OpenTelemetryInstrumentation
from opentelemetry.trace.status import StatusCode
import os
Expand Down

0 comments on commit 71cecc1

Please sign in to comment.