Skip to content

Commit

Permalink
Merge pull request #452 from l3vels/fix/image-generation
Browse files Browse the repository at this point in the history
Fix/image generation
  • Loading branch information
Chkhikvadze authored Mar 18, 2024
2 parents aed1f15 + c3bc959 commit 2eeb0a2
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 35 deletions.
1 change: 0 additions & 1 deletion apps/server/controllers/voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def get_voice_options(
end = start + per_page
paginated_azure_voices = azure_voices[start:end]

print(labsResponse.json())
combined_response = {
"elevenLabsVoices": labsResponse.json(),
"playHtVoices": play_ht_all_voices,
Expand Down
35 changes: 34 additions & 1 deletion apps/server/tools/dalle/dalle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
from typing import Optional, Type
from uuid import uuid4

import requests
from fastapi import HTTPException
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from openai import OpenAI
from pydantic import BaseModel, Field

from services.aws_s3 import AWSS3Service
from tools.base import BaseTool
from utils.model import get_llm

Expand Down Expand Up @@ -56,11 +60,40 @@ def _run(
)
chain = LLMChain(llm=llm, prompt=prompt)

image_url = DallEAPIWrapper(api_key=self.settings.openai_api_key).run(
dalle_image_url = DallEAPIWrapper(api_key=self.settings.openai_api_key).run(
chain.run(query)
)

image_url = get_final_image_url(
image_url=dalle_image_url, account=self.account
)

return image_url

except Exception as err:
return str(err)


def get_final_image_url(image_url: str, account):
response = requests.get(image_url, stream=True)
if response.status_code == 200:
content_type = response.headers["content-type"]
image_body = response.content
name = image_url.split("/")[-1] or f"image-{uuid4()}"

if "." in name:
name, ext = name.rsplit(".", 1)
else:
# Handle the case where there is no file extension
# You might want to assign a default extension or raise an error
ext = "png"

key = f"account_{account.id}/files/dalle-{uuid4()}.{ext}"

final_url = AWSS3Service.upload(
key=key, content_type=content_type, body=image_body
)

return final_url
else:
raise HTTPException(status_code=400, detail="Could not retrieve image from URL")
1 change: 1 addition & 0 deletions apps/ui/src/hooks/useApollo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ const useApollo = () => {
return new ApolloClient({
link: apolloLink,
cache: new InMemoryCache(),
connectToDevTools: true,

defaultOptions: {
watchQuery: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,30 @@ const ChatMessageListV2 = ({
}
}, [initialChat])

const appendInterval = useRef(null as any)
const [atBottom, setAtBottom] = useState(false)
const showButtonTimeoutRef = useRef(null as any)
const [showButton, setShowButton] = useState(false)

useEffect(() => {
if (!showScrollButton) scrollToEnd()
}, [data])
return () => {
clearInterval(appendInterval.current)
clearTimeout(showButtonTimeoutRef.current)
}
}, [])

useEffect(() => {
clearTimeout(showButtonTimeoutRef.current)
if (!atBottom) {
showButtonTimeoutRef.current = setTimeout(() => setShowButton(true), 500)
} else {
setShowButton(false)
}
}, [atBottom, setShowButton])

// useEffect(() => {
// if (!showButton) scrollToEnd()
// }, [data[data?.length - 1]])

return (
<StyledRoot show={true}>
Expand All @@ -191,6 +212,12 @@ const ChatMessageListV2 = ({
data={initialChat}
totalCount={data.length}
overscan={2500}
followOutput
atBottomStateChange={bottom => {
clearInterval(appendInterval.current)
setAtBottom(bottom)
}}
atBottomThreshold={6}
components={{
Footer: () => {
return (
Expand Down Expand Up @@ -284,7 +311,8 @@ const ChatMessageListV2 = ({
</>
)}
/>
{showScrollButton && (

{showButton && (
<StyledScrollButton onClick={scrollToEnd}>
<ArrowDown size={18} />
</StyledScrollButton>
Expand Down Expand Up @@ -318,7 +346,7 @@ const StyledWrapper = styled.div<{ isHidden?: boolean; isReplying?: boolean }>`
/* align-items: center; */
gap: 5px;
padding-top: 10px;
padding-bottom: 6px;
/* margin-right: 50px; */
.visible-reply {
opacity: 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,56 @@ import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
import { atomDark } from 'react-syntax-highlighter/dist/esm/styles/prism'
import remarkGfm from 'remark-gfm'
import { useModal } from 'hooks'
import { memo } from 'react'
import { memo, useEffect, useState } from 'react'
import { t } from 'i18next'
import Loader from 'share-ui/components/Loader/Loader'

const YOUTUBE_REGEX = /^https:\/\/www\.youtube\.com\/watch\?v=([a-zA-Z0-9_-]+)&/
const IMAGE_REGEX = /\.(gif|jpe?g|tiff?|png|webp|bmp)$/i
// const SETTINGS_REGEX = /\/setting/
const TOOLKIT_REGEX = /\/toolkits\/[^/]+/
const VOICE_REGEX = /\/integrations\/voice\/[^/]+/
const SETTINGS_REGEX = /\/integrations\?setting=([^/]+)/
const DALLE_IMAGE_REGEX =
/https:\/\/oaidalleapiprodscus\.blob\.core\.windows\.net\/private\/or.*?&sig=[\w%/]+/

const INCOMPLETE_URL_REGEX = /https?:\/\/\S*$/

const AiMessageMarkdown = ({ isReply = false, children }: { isReply?: boolean; children: any }) => {
const { openModal } = useModal()

const [loadingUrl, setLoadingUrl] = useState(false)

useEffect(() => {
// Check if the last part of the children string is an incomplete URL
const match = children.match(INCOMPLETE_URL_REGEX)
if (match) {
// If there's an incomplete URL, start loading
setLoadingUrl(true)
// Here you would have logic to determine when the URL is complete
// For demonstration, let's assume the URL is complete when there are no more tokens coming for 1 second
const timeoutId = setTimeout(() => {
setLoadingUrl(false)
}, 1000) // Wait for 1 second of no new tokens to consider the URL complete

// Clear the timeout if the component unmounts or the children update before the timeout is reached
return () => clearTimeout(timeoutId)
} else {
// If there's no incomplete URL, ensure loading is not shown
setLoadingUrl(false)
}
}, [children])

return (
<StyledReactMarkdown
isReply={isReply}
children={children}
remarkPlugins={[[remarkGfm, { singleTilde: false }]]}
components={{
img: ({ node, ...props }) => <StyledImg {...props} />,
table: ({ node, ...props }) => <StyledTable {...props} />,
a: ({ href, children }) => {
// console.log('href', href)
// console.log('children', children)
if (loadingUrl) {
return <Loader size={40} />
}

if (YOUTUBE_REGEX.test(href as string)) {
const videoId = (href as string).match(YOUTUBE_REGEX)?.[1]
Expand All @@ -51,12 +76,6 @@ const AiMessageMarkdown = ({ isReply = false, children }: { isReply?: boolean; c

if (IMAGE_REGEX.test(href as string)) {
const imageUrl = href as string
return <img src={imageUrl} alt={children as string} />
}

if (DALLE_IMAGE_REGEX.test(href as string)) {
const imageUrl = href as string

return <StyledImg src={imageUrl} alt={children as string} />
}

Expand Down
2 changes: 1 addition & 1 deletion apps/ui/src/modals/AIChatModal/components/ChatV2.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const ChatV2 = ({ chatSessionId }: { chatSessionId?: string }) => {

const agentId = urlParams.get('agent')
const teamId = urlParams.get('team')
const chatId = urlParams.get('chat') || chatSessionId
const chatId = urlParams.get('chat') || urlParams.get('session') || chatSessionId

const { thinking, setThinking, socket } = useChatState()

Expand Down
4 changes: 3 additions & 1 deletion apps/ui/src/modals/AIChatModal/hooks/useChatSocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const useChatSocket = ({ userId, createdChatId }: ChatSocketProps) => {

const agentId = urlParams.get('agent')
const teamId = urlParams.get('team')
const chatId = urlParams.get('chat') || createdChatId || ''
const chatId = urlParams.get('chat') || urlParams.get('session') || createdChatId || ''

const groupId = getSessionId({
user,
Expand All @@ -46,6 +46,7 @@ const useChatSocket = ({ userId, createdChatId }: ChatSocketProps) => {

const response = await fetch(url)
const data = await response.json()

return data.url
}, [user?.id])

Expand All @@ -68,6 +69,7 @@ const useChatSocket = ({ userId, createdChatId }: ChatSocketProps) => {
onUserTypingEvent(e)
}
if (data.type === 'CHAT_MESSAGE_ADDED') {
console.log('socket', data)
upsertChatMessageInCache(data.chat_message, {
agentId: data.agent_id,
teamId: data.team_id,
Expand Down
1 change: 1 addition & 0 deletions apps/ui/src/modals/AIChatModal/hooks/useUpdateChatCache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const useUpdateChatCache = () => {
isNil,
)
}

apolloClient.cache.updateQuery(
{ query: CHAT_MESSAGES_GQL, variables: queryVariables },
data => {
Expand Down
7 changes: 6 additions & 1 deletion apps/ui/src/pages/Agents/AgentTables/AgentSessionsTable.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import Table from 'components/Table'

import ChatV2 from 'modals/AIChatModal/components/ChatV2'
import { ChatContextProvider } from 'modals/AIChatModal/context/ChatContext'
import { StyledCloseIcon } from 'pages/Home/GetStarted/GetStartedContainer'

import { useColumn } from 'pages/Sessions/columnConfig'
Expand Down Expand Up @@ -68,7 +69,11 @@ const AgentSessionsTable = ({ agentId }: { agentId: string }) => {
kind={IconButton.kinds?.TERTIARY}
size={IconButton.sizes?.SMALL}
/>
{sessionId && <ChatV2 chatSessionId={sessionId} />}
{sessionId && (
<ChatContextProvider>
<ChatV2 />
</ChatContextProvider>
)}
</StyledChatWrapper>
</StyledRoot>
)
Expand Down
33 changes: 18 additions & 15 deletions apps/ui/src/services/chat/useChatMessagesService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,27 @@ type UseChatMessagesService = {
chatId?: Nullable<string>
}

export const useChatMessagesService = ({
agentId,
teamId,
chatId,
}: UseChatMessagesService) => {
const { data, error, loading, refetch } = useQuery(CHAT_MESSAGES_GQL, {
// Omit undefined variables to exclude in query params
variables: omitBy(
export const useChatMessagesService = ({ agentId, teamId, chatId }: UseChatMessagesService) => {
let queryVariables = omitBy(
{
agent_id: agentId,
team_id: teamId,
chat_id: chatId,
},
isNil,
)
if (chatId) {
queryVariables = omitBy(
{
agent_id: agentId,
team_id: teamId,
chat_id: chatId,
},
isNil,
),
)
}

const { data, error, loading, refetch } = useQuery(CHAT_MESSAGES_GQL, {
// Omit undefined variables to exclude in query params
variables: queryVariables,
})

return {
Expand All @@ -37,10 +43,7 @@ export const useChatMessagesService = ({
}
}

export const useChatMessagesHistoryService = ({
agentId,
teamId,
}: UseChatMessagesService) => {
export const useChatMessagesHistoryService = ({ agentId, teamId }: UseChatMessagesService) => {
const { data, error, loading, refetch } = useQuery(CHAT_MESSAGES_HISTORY_GQL, {
// Omit undefined variables to exclude in query params
variables: omitBy(
Expand Down

0 comments on commit 2eeb0a2

Please sign in to comment.