Skip to content

Commit

Permalink
Merge pull request #52 from rimgosu/main
Browse files Browse the repository at this point in the history
feat: tts, gpt
  • Loading branch information
rimgosu authored Feb 15, 2024
2 parents 44405f2 + 707d9d5 commit 8a46ae5
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/fastapi-env/
/temp/
/temp/voice_temp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
8 changes: 4 additions & 4 deletions app/api/api_v1/endpoints/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from typing import List

from app.database import get_db
from app.services.log_service import (create_log_service, get_log_service, get_logs_by_gree_service, get_logs_service, delete_log_service)
from app.schemas.LogDto import CreateLogDto, LogResponseDto
from app.services.log_service import (create_usertalk_log_service, get_log_service, get_logs_by_gree_service, get_logs_service, delete_log_service)
from app.schemas.LogDto import CreateUserTalkLogDto, LogResponseDto

router = APIRouter()

@router.post("/", response_model=LogResponseDto)
async def create_log(log_dto: CreateLogDto, db: AsyncSession = Depends(get_db)):
return await create_log_service(db, log_dto)
async def create_log(log_dto: CreateUserTalkLogDto, db: AsyncSession = Depends(get_db)):
    """Create a log entry from the request body and return the stored record.

    Args:
        log_dto: Payload describing the log to create.
        db: Database session injected through ``get_db``.

    Returns:
        Whatever ``create_usertalk_log_service`` returns for the new log.
    """
    created_log = await create_usertalk_log_service(db, log_dto)
    return created_log

@router.get("/{log_id}", response_model=LogResponseDto)
async def read_log(log_id: int, db: AsyncSession = Depends(get_db)):
Expand Down
3 changes: 2 additions & 1 deletion app/crud/crud_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from app.models.models import Log # 모델 파일 경로에 따라 수정해야 할 수 있음

# 로그 생성
async def create_log(db: AsyncSession, gree_id: int, log_type, content: str):
async def create_log(db: AsyncSession, gree_id: int, log_type, content: str, voice_url: str):
db_log = Log(
gree_id=gree_id,
log_type=log_type,
content=content,
voice_url=voice_url,
register_at=datetime.now()
)
db.add(db_log)
Expand Down
9 changes: 9 additions & 0 deletions app/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,12 @@ class FileTypeEnum(PyEnum):
class EmotionTypeEnum(PyEnum):
HAPPY = "HAPPY"
UNHAPPY = "UNHAPPY"

@unique
class VoiceTypeEnum(PyEnum):
    """Voice presets selectable for text-to-speech.

    Each member's value is the lowercase voice identifier that is sent as the
    ``voice`` field of the speech request.
    """

    ALLOY = "alloy"
    ECHO = "echo"
    FABLE = "fable"
    ONYX = "onyx"
    NOVA = "nova"
    SHIMMER = "shimmer"
6 changes: 5 additions & 1 deletion app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy import Column, Integer, String, DateTime, Enum, ForeignKey, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from app.models.enums import RoleEnum, StatusEnum, GradeEnum, LogTypeEnum, FileTypeEnum
from app.models.enums import RoleEnum, StatusEnum, GradeEnum, LogTypeEnum, FileTypeEnum, VoiceTypeEnum
from datetime import datetime

Base = declarative_base()
Expand Down Expand Up @@ -34,6 +34,8 @@ class Gree(Base):
prompt_mbti = Column(String(255))
status = Column(Enum(StatusEnum))
isFavorite = Column(Boolean, default=False)
# 그리는 프롬프트 엔지니어링을 통해 다음과 같이 TTS의 목소리가 결정되어야한다.
voice_type = Column(Enum(VoiceTypeEnum), default=VoiceTypeEnum.ALLOY)
register_at = Column(DateTime, nullable=False, default=datetime.now())

member = relationship("Member", back_populates="gree")
Expand Down Expand Up @@ -62,6 +64,8 @@ class Log(Base):
gree_id = Column(Integer, ForeignKey('gree.gree_id'), nullable=False)
log_type = Column(Enum(LogTypeEnum), nullable=False)
content = Column(String(1000))
# "GREE_TALK"일 경우 음성데이터가 Azure에 업로드 되어야한다.
voice_url = Column(String(255))
register_at = Column(DateTime, nullable=False, default=datetime.now())

gree = relationship("Gree", back_populates="log")
Expand Down
8 changes: 7 additions & 1 deletion app/schemas/LogDto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@
from app.models.enums import LogTypeEnum

# 생성을 위한 DTO
class CreateLogDto(BaseModel):
class CreateUserTalkLogDto(BaseModel):
gree_id: int
log_type: LogTypeEnum
content: str

class CreateGreeTalkLogDto(BaseModel):
    """Payload for creating a log entry that carries a voice recording URL."""

    # Identifier of the Gree the log belongs to.
    gree_id: int
    # Kind of log entry being recorded.
    log_type: LogTypeEnum
    # Text of the message.
    content: str
    # URL of the uploaded voice file (Azure).
    voice_url: str

# 응답을 위한 DTO
class LogResponseDto(BaseModel):
id: int
Expand Down
2 changes: 1 addition & 1 deletion app/schemas/gree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ class Gree(BaseModel):
prompt_mbti: Optional[str] = None
status: StatusEnum
isFavorite: bool

class Config:
from_attributes = True
94 changes: 84 additions & 10 deletions app/services/ai_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import os
import uuid
import aiohttp
import openai
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from azure.storage.blob import BlobServiceClient, ContentSettings

from app.crud.crud_gree import crud_get_gree_by_id_only
from app.models.enums import VoiceTypeEnum
from app.schemas.ChatDto import ChatRequestDto
from app.schemas.LogDto import CreateLogDto
from app.services.log_service import create_log_service
from app.schemas.LogDto import CreateGreeTalkLogDto, CreateUserTalkLogDto
from app.services.log_service import create_greetalk_log_service, create_usertalk_log_service

# chat 테스트를 위한 서비스이다.
async def chat_with_openai_test_service(chat_request):
Expand Down Expand Up @@ -47,10 +52,13 @@ async def chat_with_openai_test_service(chat_request):
# 1. gree_id로 gree를 하나 조회한다. ✅
# 2. gree의 정보를 토대로 chat_request에 있는 정보를 채워넣는다. ✅
# 3. 즉 엔드포인트에서 받아야 할 정보는 gree_id, message 이 둘이면 된다. ✅
# 4. gree_id, message로 입력받은 것은 "USER_TALK" Log로 남아야 한다.
# 5. gree_id, message로 출력되는 것은 "GREE_TALK" Log로 데이터베이스에 저장되어야 한다.
# 4. gree_id, message로 입력받은 것은 "USER_TALK" Log로 남아야 한다. ✅
# 5. gree_id, message로 출력되는 것은 "GREE_TALK" Log로 데이터베이스에 저장되어야 한다. ✅
# 6. "GREE_TALK" 메시지는 TTS과정을 거쳐야한다.
# 7. TTS를 거친 메시지는 Azure에 저장되어야한다.
# 8. Azure에 저장된 URL은 Log에 저장되어야한다.
# 9. GREE의 성격을 지정할때(Update Gree) Gree객체의 VoiceTypeEnum이 지정되어야한다.
async def chat_with_openai_service(db: AsyncSession, chat_request: ChatRequestDto):

gree = await crud_get_gree_by_id_only(db, chat_request.gree_id)
system_message = (
"이 대화는 한국어로 진행됩니다. 모든 응답은 한국어로 제공되어야 합니다. "
Expand All @@ -67,13 +75,13 @@ async def chat_with_openai_service(db: AsyncSession, chat_request: ChatRequestDt
f"당신은 {gree.prompt_gender}, {gree.prompt_age}살, 이름은 {gree.gree_name}, MBTI는 각각의 성향이 강하게 나타나는 {gree.prompt_mbti}입니다."
)

createUserLogDto = CreateLogDto(
createUserLogDto = CreateUserTalkLogDto(
gree_id=gree.id,
log_type='USER_TALK',
content=chat_request.message
)

user_talk = await create_log_service(db ,createUserLogDto)
user_talk = await create_usertalk_log_service(db ,createUserLogDto)

try:
completion = openai.ChatCompletion.create(
Expand All @@ -87,16 +95,82 @@ async def chat_with_openai_service(db: AsyncSession, chat_request: ChatRequestDt
{"role": "user", "content": chat_request.message}
]
)
gree_talk= completion.choices[0].message.content

mp3_path = await text_to_speech(gree.voice_type, gree_talk)
voice_url_azure= await upload_mp3_azure(mp3_path)

createGptLogDto = CreateLogDto(
createGptLogDto = CreateGreeTalkLogDto(
gree_id=gree.id,
log_type='GREE_TALK',
content=completion.choices[0].message.content
content=gree_talk,
voice_url=voice_url_azure
)

gpt_talk = await create_log_service(db ,createGptLogDto)
gpt_talk = await create_greetalk_log_service(db ,createGptLogDto)

return {"user_talk": user_talk, "gpt_talk": gpt_talk}
except Exception as exc:
print(f"An error occurred: {exc}")
raise HTTPException(status_code=500, detail="An error occurred while processing your request.")


async def text_to_speech(voice_type: VoiceTypeEnum, gree_talk: str) -> str:
    """Convert ``gree_talk`` to speech via the OpenAI speech endpoint and save it as MP3.

    Args:
        voice_type: Voice preset; its ``value`` is sent as the ``voice`` field.
        gree_talk: Text to synthesize.

    Returns:
        Path of the MP3 file written under ``/temp/voice_temp/``.

    Raises:
        HTTPException: 500 when ``OPENAI_API_KEY`` is not set; 502 when the
            speech endpoint answers with a status other than 200.
    """
    url = "https://api.openai.com/v1/audio/speech"
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')  # Read the API key from the environment.
    if not OPENAI_API_KEY:
        # Fail early with a clear message instead of sending "Bearer None".
        raise HTTPException(status_code=500, detail="OpenAI API key is not set in environment variables.")

    payload = {
        "model": "tts-1",
        "input": gree_talk,
        "voice": voice_type.value
    }
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=headers) as response:
            if response.status != 200:
                # Previously an error sentence was returned in place of a path,
                # which the caller then tried to open and upload as a file.
                raise HTTPException(
                    status_code=502,
                    detail=f"Text-to-speech request failed with status {response.status}."
                )
            content = await response.read()

    # Create the target directory if needed; exist_ok avoids a check-then-create race.
    dir_path = '/temp/voice_temp/'
    os.makedirs(dir_path, exist_ok=True)

    # A unique file name per call keeps concurrent requests that share a voice
    # from overwriting each other's audio before it is uploaded.
    mp3_path = os.path.join(dir_path, f'{voice_type.value}_{uuid.uuid4().hex}_speech.mp3')

    with open(mp3_path, 'wb') as f:
        f.write(content)
    return mp3_path


async def upload_mp3_azure(mp3_path: str) -> str:
    """Upload a local MP3 file to Azure Blob Storage and return the blob URL.

    The file is stored in the ``greefile`` container under
    ``logVoices/<uuid4>_<original file name>`` with content type ``audio/mpeg``.

    Args:
        mp3_path: Path of the local MP3 file to upload.

    Returns:
        The ``url`` attribute of the blob client for the uploaded blob.

    Raises:
        HTTPException: 500 when ``AZURE_ACCOUNT_KEY`` is not set, or when
            reading the file or uploading it fails.
    """
    container_name = "greefile"

    # Checked outside the try block so this error reaches the caller as-is;
    # inside it, the generic handler below caught and re-wrapped it.
    AZURE_ACCOUNT_KEY = os.getenv("AZURE_ACCOUNT_KEY")
    if not AZURE_ACCOUNT_KEY:
        raise HTTPException(status_code=500, detail="Azure account key is not set in environment variables.")

    try:
        connection_string = f"DefaultEndpointsProtocol=https;AccountName=greedotstorage;AccountKey={AZURE_ACCOUNT_KEY};EndpointSuffix=core.windows.net"
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)

        # Derive a unique blob name from the local file name.
        file_name = os.path.basename(mp3_path)
        unique_file_name = f"logVoices/{uuid.uuid4()}_{file_name}"

        blob_client = blob_service_client.get_blob_client(container=container_name, blob=unique_file_name)

        # Upload the file.
        # NOTE(review): this looks like the synchronous Azure client, so the
        # upload would block the event loop while it runs — confirm.
        with open(mp3_path, "rb") as data:
            blob_client.upload_blob(data, overwrite=True, content_settings=ContentSettings(content_type='audio/mpeg'))

        # Return the URL of the uploaded blob.
        return blob_client.url
    except Exception as e:
        print(f"An error occurred while uploading MP3 to Azure: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred while uploading MP3 to Azure: {e}") from e
12 changes: 8 additions & 4 deletions app/services/log_service.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
# services/log_service.py
from typing import List, Optional
from sqlalchemy import null
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import HTTPException
from app.crud.crud_log import (create_log as crud_create_log,
get_log as crud_get_log,
get_logs as crud_get_logs,
update_log as crud_update_log,
delete_log as crud_delete_log)
from app.schemas.LogDto import CreateLogDto
from app.schemas.LogDto import CreateGreeTalkLogDto, CreateUserTalkLogDto
from app.models.models import Log
from sqlalchemy.future import select

async def create_log_service(db: AsyncSession, log_dto: CreateLogDto) -> Log:
db_log = await crud_create_log(db, log_dto.gree_id, log_dto.log_type, log_dto.content)
async def create_usertalk_log_service(db: AsyncSession, log_dto: CreateUserTalkLogDto) -> Log:
    """Create a log entry that has no voice recording attached.

    Args:
        db: Database session passed through to the CRUD layer.
        log_dto: Payload providing ``gree_id``, ``log_type`` and ``content``.

    Returns:
        The value returned by ``crud_create_log`` for the new log.
    """
    # This DTO carries no voice URL, so store NULL. ``None`` is used because
    # the imported ``null`` is a factory function: passing it uncalled would
    # hand a function object to the column instead of a NULL value.
    db_log = await crud_create_log(db, log_dto.gree_id, log_dto.log_type, log_dto.content, None)
    return db_log

async def create_greetalk_log_service(db: AsyncSession, log_dto: CreateGreeTalkLogDto) -> Log:
    """Create a log entry together with the URL of its voice recording.

    Args:
        db: Database session passed through to the CRUD layer.
        log_dto: Payload providing ``gree_id``, ``log_type``, ``content`` and
            ``voice_url``.

    Returns:
        The value returned by ``crud_create_log`` for the new log.
    """
    return await crud_create_log(
        db,
        log_dto.gree_id,
        log_dto.log_type,
        log_dto.content,
        log_dto.voice_url,
    )

async def get_log_service(db: AsyncSession, log_id: int) -> Optional[Log]:
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@
app.include_router(api_router, prefix=settings.API_v1_STR)

if __name__ == "__main__":
# init_db()
init_db()
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, log_level="debug")

0 comments on commit 8a46ae5

Please sign in to comment.