From 695f97907d97b9ef418042d610d7145f1a729197 Mon Sep 17 00:00:00 2001 From: fegloff Date: Tue, 16 Jan 2024 19:46:20 -0500 Subject: [PATCH] add stream completion for vision + add vision command to work with multiple img url --- src/modules/llms/index.ts | 18 +++++----- src/modules/open-ai/api/openAi.ts | 39 +++------------------ src/modules/open-ai/helpers.ts | 17 ++++----- src/modules/open-ai/index.ts | 57 +++++++++++++++++++------------ src/modules/types.ts | 1 + 5 files changed, 60 insertions(+), 72 deletions(-) diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts index fbd0d18c..4696ddad 100644 --- a/src/modules/llms/index.ts +++ b/src/modules/llms/index.ts @@ -85,7 +85,7 @@ export class LlmsBot implements PayableBot { return undefined } - private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string | undefined { + private isSupportedUrlReply (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined { return getUrlFromText(ctx) } @@ -251,14 +251,16 @@ export class LlmsBot implements PayableBot { async onUrlReplyHandler (ctx: OnMessageContext | OnCallBackQueryData): Promise { try { - const url = getUrlFromText(ctx) ?? '' - const prompt = ctx.message?.text ?? 'summarize' - const collection = ctx.session.collections.activeCollections.find(c => c.url === url) - const newPrompt = `${prompt}` // ${url} - if (collection) { - await this.queryUrlCollection(ctx, url, newPrompt) + const url = getUrlFromText(ctx) + if (url) { + const prompt = ctx.message?.text ?? 'summarize' + const collection = ctx.session.collections.activeCollections.find(c => c.url === url[0]) + const newPrompt = `${prompt}` // ${url} + if (collection) { + await this.queryUrlCollection(ctx, url[0], newPrompt) + } + ctx.transient.analytics.actualResponseTime = now() } - ctx.transient.analytics.actualResponseTime = now() } catch (e: any) { await this.onError(ctx, e) } diff --git a/src/modules/open-ai/api/openAi.ts b/src/modules/open-ai/api/openAi.ts index d64957a8..ad34603d 100644 --- a/src/modules/open-ai/api/openAi.ts +++ b/src/modules/open-ai/api/openAi.ts @@ -18,7 +18,7 @@ import { ChatGPTModelsEnum } from '../types' import type fs from 'fs' -import { type ChatCompletionMessageParam, type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions' +import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions' import { type Stream } from 'openai/streaming' const openai = new OpenAI({ apiKey: config.openAiKey }) @@ -50,34 +50,6 @@ export async function postGenerateImg ( return response.data } -export async function imgInquiryWithVision ( - img: string, - prompt: string, - ctx: OnMessageContext | OnCallBackQueryData -): Promise { - console.log(img, prompt) - const payLoad = { - model: 'gpt-4-vision-preview', - messages: [ - { - role: 'user', - content: [ - { type: 'text', text: 'What’s in this image?' }, - { - type: 'image_url', - image_url: { url: img } - } - ] - } - ], - max_tokens: 300 - } - console.log('HELLO') - const response = await openai.chat.completions.create(payLoad as unknown as ChatCompletionCreateParamsNonStreaming) - console.log(response.choices[0].message?.content) - return 'hi' -} - export async function alterGeneratedImg ( prompt: string, filePath: string, @@ -207,11 +179,10 @@ export const streamChatCompletion = async ( } export const streamChatVisionCompletion = async ( - conversation: ChatConversation[], ctx: OnMessageContext | OnCallBackQueryData, model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW, prompt: string, - imgUrl: string, + imgUrls: string[], msgId: number, limitTokens = true ): Promise => { @@ -224,10 +195,10 @@ export const streamChatVisionCompletion = async ( role: 'user', content: [ { type: 'text', text: prompt }, - { + ...imgUrls.map(img => ({ type: 'image_url', - image_url: { url: imgUrl } - } + image_url: { url: img } + })) ] } ], diff --git a/src/modules/open-ai/helpers.ts b/src/modules/open-ai/helpers.ts index 468be8d8..5a858a41 100644 --- a/src/modules/open-ai/helpers.ts +++ b/src/modules/open-ai/helpers.ts @@ -9,7 +9,7 @@ import { isValidUrl } from './utils/web-crawler' export const SupportedCommands = { chat: { name: 'chat' }, ask: { name: 'ask' }, - // sum: { name: 'sum' }, + vision: { name: 'vision' }, ask35: { name: 'ask35' }, new: { name: 'new' }, gpt4: { name: 'gpt4' }, @@ -263,13 +263,14 @@ export const limitPrompt = (prompt: string): string => { return `${prompt} in around ${config.openAi.chatGpt.wordLimit} words` } -export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string | undefined => { - const entities = ctx.message?.reply_to_message?.entities - if (entities) { - const urlEntity = entities.find(e => e.type === 'url') - if (urlEntity) { - const url = ctx.message?.reply_to_message?.text?.slice(urlEntity.offset, urlEntity.offset + urlEntity.length) - return url +export const getUrlFromText = (ctx: OnMessageContext | OnCallBackQueryData): string[] | undefined => { + const entities = ctx.message?.entities ? ctx.message?.entities : ctx.message?.reply_to_message?.entities + const text = ctx.message?.text ? ctx.message?.text : ctx.message?.reply_to_message?.text + if (entities && text) { + const urlEntity = entities.filter(e => e.type === 'url') + if (urlEntity.length > 0) { + const urls = urlEntity.map(e => text.slice(e.offset, e.offset + e.length)) + return urls } } return undefined diff --git a/src/modules/open-ai/index.ts b/src/modules/open-ai/index.ts index 805528ec..746cc18c 100644 --- a/src/modules/open-ai/index.ts +++ b/src/modules/open-ai/index.ts @@ -29,6 +29,7 @@ import { sleep } from '../sd-images/utils' import { getMessageExtras, getPromptPrice, + getUrlFromText, hasChatPrefix, hasDallePrefix, hasNewPrefix, @@ -229,6 +230,24 @@ export class OpenAIBot implements PayableBot { return } + if (ctx.hasCommand(SupportedCommands.vision.name)) { + const photoUrl = getUrlFromText(ctx) + if (photoUrl) { + const prompt = ctx.match + ctx.session.openAi.imageGen.imgRequestQueue.push({ + prompt, + photoUrl, + command: !isNaN(+prompt) ? 'alter' : 'vision' + }) + if (!ctx.session.openAi.imageGen.isProcessingQueue) { + ctx.session.openAi.imageGen.isProcessingQueue = true + await this.onImgRequestHandler(ctx).then(() => { + ctx.session.openAi.imageGen.isProcessingQueue = false + }) + } + } + } + if ( ctx.hasCommand([SupportedCommands.dalle.name, SupportedCommands.dalleImg.name, @@ -560,7 +579,7 @@ export class OpenAIBot implements PayableBot { } else if (img?.command === 'alter') { await this.onAlterImage(img?.photo, img?.prompt, ctx) } else { - await this.onInquiryImage(img?.photo, img?.prompt, ctx) + await this.onInquiryImage(img?.photo, img?.photoUrl, img?.prompt, ctx) } ctx.chatAction = null } else { @@ -609,17 +628,23 @@ export class OpenAIBot implements PayableBot { } } - onInquiryImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise => { + onInquiryImage = async (photo: PhotoSize[] | undefined, photoUrl: string[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise => { try { if (ctx.session.openAi.imageGen.isEnabled) { - const fileId = photo?.pop()?.file_id // with pop() get full image quality - if (!fileId) { - await ctx.reply('Cannot retrieve the image file. Please try again.') - ctx.transient.analytics.actualResponseTime = now() - return + // let filePath = '' + let imgList = [] + if (photo) { + const fileId = photo?.pop()?.file_id // with pop() get full image quality + if (!fileId) { + await ctx.reply('Cannot retrieve the image file. Please try again.') + ctx.transient.analytics.actualResponseTime = now() + return + } + const file = await ctx.api.getFile(fileId) + imgList.push(`${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}`) + } else { + imgList = photoUrl ?? [] } - const file = await ctx.api.getFile(fileId) - const filePath = `${config.openAi.dalle.telegramFileUrl}${config.telegramBotAuthToken}/${file.file_path}` const msgId = ( await ctx.reply('...', { message_thread_id: @@ -627,20 +652,8 @@ export class OpenAIBot implements PayableBot { ctx.message?.reply_to_message?.message_thread_id }) ).message_id - const messages = [ - { - role: 'user', - content: [ - { type: 'text', text: prompt }, - { - type: 'image_url', - image_url: { url: filePath } - } - ] - } - ] const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW - const completion = await streamChatVisionCompletion(messages, ctx, model, prompt ?? '', filePath, msgId, true) + const completion = await streamChatVisionCompletion(ctx, model, prompt ?? '', imgList, msgId, true) if (completion) { ctx.transient.analytics.sessionState = RequestState.Success ctx.transient.analytics.actualResponseTime = now() diff --git a/src/modules/types.ts b/src/modules/types.ts index 5f8b0af3..ef9073e4 100644 --- a/src/modules/types.ts +++ b/src/modules/types.ts @@ -57,6 +57,7 @@ export interface ImageRequest { command?: 'dalle' | 'alter' | 'vision' prompt?: string photo?: PhotoSize[] | undefined + photoUrl?: string[] } export interface ChatGptSessionData { model: string