add prompt command and refactor context logic #372

Merged · 1 commit · Nov 6, 2024
12 changes: 12 additions & 0 deletions src/bot.ts
@@ -516,12 +516,24 @@ const logErrorHandler = (ex: any): void => {
bot.command('new', async (ctx) => {
writeCommandLog(ctx as OnMessageContext).catch(logErrorHandler)
await openAiBot.onStop(ctx as OnMessageContext)
await claudeBot.onStop(ctx as OnMessageContext) // any Bot with 'llms' as sessionKey works.
return await ctx.reply('Chat history reset', {
parse_mode: 'Markdown',
message_thread_id: ctx.message?.message_thread_id
})
})

bot.command('prompt', async (ctx) => {
const context = ctx.match
if (context) {
ctx.session.currentPrompt = context
}
await ctx.reply(`Prompt set to: _${ctx.session.currentPrompt}_`, {
parse_mode: 'Markdown',
message_thread_id: ctx.message?.message_thread_id
})
})

bot.command('more', async (ctx) => {
writeCommandLog(ctx as OnMessageContext).catch(logErrorHandler)
return await ctx.reply(commandsHelpText.more, {
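The `/prompt` handler relies on grammY's `ctx.match`, which holds whatever text follows the command, and new sessions are seeded with the configured default prompt (see `src/helpers.ts` below), so `/prompt` only overrides it per session. A minimal, self-contained sketch of the behavior, assuming plain grammY with a session flavor; the token and default prompt are placeholders, not the repo's actual wiring:

```ts
import { Bot, Context, session, type SessionFlavor } from 'grammy'

interface SessionData { currentPrompt: string }
type BotContext = Context & SessionFlavor<SessionData>

const bot = new Bot<BotContext>('BOT_TOKEN') // placeholder token

bot.use(session({
  initial: (): SessionData => ({ currentPrompt: 'You are a helpful assistant.' })
}))

bot.command('prompt', async (ctx) => {
  // "/prompt You are a pirate bot" -> ctx.match === "You are a pirate bot"
  if (ctx.match) {
    ctx.session.currentPrompt = ctx.match
  }
  await ctx.reply(`Prompt set to: _${ctx.session.currentPrompt}_`, { parse_mode: 'Markdown' })
})

void bot.start()
```

Note that a bare `/prompt` leaves the session prompt untouched and simply echoes the current one back.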
1 change: 1 addition & 0 deletions src/helpers.ts
@@ -56,6 +56,7 @@ export function createInitialSessionData (): BotSessionData {
isVoiceForwardingEnabled: config.voiceMemo.isVoiceForwardingEnabled
},
currentModel: LlmModelsEnum.GPT_4O,
currentPrompt: config.openAi.chatGpt.chatCompletionContext,
lastBroadcast: ''
}
}
11 changes: 7 additions & 4 deletions src/modules/llms/api/athropic.ts
@@ -24,11 +24,12 @@ const API_ENDPOINT = config.llms.apiEndpoint // 'http://127.0.0.1:5000' // confi
export const anthropicCompletion = async (
conversation: ChatConversation[],
model = LlmModelsEnum.CLAUDE_3_OPUS,
ctx: OnMessageContext | OnCallBackQueryData,
parameters?: ModelParameters
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} completion`)
parameters = parameters ?? {
system: config.openAi.chatGpt.chatCompletionContext,
system: ctx.session.currentPrompt,
max_tokens: +config.openAi.chatGpt.maxTokens
}
const data = {
@@ -68,11 +69,12 @@ export const anthropicCompletion = async (
export const xaiCompletion = async (
conversation: ChatConversation[],
model = LlmModelsEnum.GROK,
ctx: OnMessageContext | OnCallBackQueryData,
parameters?: ModelParameters
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} completion`)
parameters = parameters ?? {
system: config.openAi.chatGpt.chatCompletionContext,
system: ctx.session.currentPrompt,
max_tokens: +config.openAi.chatGpt.maxTokens
}
const data = {
@@ -119,7 +121,7 @@ export const anthropicStreamCompletion = async (
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} stream completion`)
parameters = parameters ?? {
system: config.openAi.chatGpt.chatCompletionContext,
system: ctx.session.currentPrompt,
max_tokens: +config.openAi.chatGpt.maxTokens
}
const data = {
@@ -217,11 +219,12 @@ export const anthropicStreamCompletion = async (
export const toolsChatCompletion = async (
conversation: ChatConversation[],
model = LlmModelsEnum.CLAUDE_3_OPUS,
ctx: OnMessageContext | OnCallBackQueryData,
parameters?: ModelParameters
): Promise<LlmCompletion> => {
logger.info(`Handling ${model} completion`)
parameters = parameters ?? {
system: config.openAi.chatGpt.chatCompletionContext,
system: ctx.session.currentPrompt,
max_tokens: +config.openAi.chatGpt.maxTokens
}
const input = {
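All four exported functions in this file follow the same pattern: `ctx` is threaded through, and when the caller supplies no `ModelParameters`, the `??` fallback now builds the default `system` prompt from the session instead of static config. A condensed sketch of that defaulting rule, with simplified stand-in types (the `800` token cap is illustrative, standing in for `+config.openAi.chatGpt.maxTokens`):

```ts
interface ModelParameters { system?: string, max_tokens: number }
interface SessionContext { session: { currentPrompt: string } }

function resolveParameters (ctx: SessionContext, parameters?: ModelParameters): ModelParameters {
  // `??` only fires when `parameters` is null or undefined. A caller-supplied
  // object passes through untouched even if it lacks `system`, which is why
  // the bot classes stamp `parameters.system` themselves (see claudeBot.ts).
  return parameters ?? { system: ctx.session.currentPrompt, max_tokens: 800 }
}

const ctx = { session: { currentPrompt: 'You are a helpful assistant.' } }
console.log(resolveParameters(ctx)) // falls back to the session prompt
console.log(resolveParameters(ctx, { max_tokens: 100 })) // no system injected here
```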
15 changes: 8 additions & 7 deletions src/modules/llms/api/openai.ts
@@ -81,25 +81,26 @@ export async function alterGeneratedImg (

type ConversationOutput = Omit<ChatConversation, 'timestamp' | 'model' | 'id' | 'author' | 'numSubAgents'>

const prepareConversation = (conversation: ChatConversation[], model: string): ConversationOutput[] => {
const prepareConversation = (conversation: ChatConversation[], model: string, ctx: OnMessageContext | OnCallBackQueryData): ConversationOutput[] => {
const messages = conversation.filter(c => c.model === model).map(m => { return { content: m.content, role: m.role } })
if (messages.length !== 1 || model === LlmModelsEnum.O1) {
return messages
}
const systemMessage = {
role: 'system',
content: config.openAi.chatGpt.chatCompletionContext
content: ctx.session.currentPrompt
}
return [systemMessage, ...messages]
}

export async function chatCompletion (
conversation: ChatConversation[],
model = config.openAi.chatGpt.model,
ctx: OnMessageContext | OnCallBackQueryData,
limitTokens = true,
parameters?: ModelParameters
): Promise<LlmCompletion> {
const messages = prepareConversation(conversation, model)
const messages = prepareConversation(conversation, model, ctx)
parameters = parameters ?? {
max_completion_tokens: config.openAi.chatGpt.maxTokens,
temperature: config.openAi.dalle.completions.temperature
@@ -139,15 +140,15 @@ export async function chatCompletion (

export const streamChatCompletion = async (
conversation: ChatConversation[],
ctx: OnMessageContext | OnCallBackQueryData,
model = LlmModelsEnum.GPT_4,
ctx: OnMessageContext | OnCallBackQueryData,
msgId: number,
limitTokens = true,
parameters?: ModelParameters
): Promise<LlmCompletion> => {
let completion = ''
let wordCountMinimum = 2
const messages = prepareConversation(conversation, model)
const messages = prepareConversation(conversation, model, ctx)
parameters = parameters ?? {
max_completion_tokens: config.openAi.chatGpt.maxTokens,
temperature: config.openAi.dalle.completions.temperature || 0.8
@@ -322,10 +323,10 @@ export const streamChatVisionCompletion = async (
}
}

export async function improvePrompt (promptText: string, model: string): Promise<string> {
export async function improvePrompt (promptText: string, model: string, ctx: OnMessageContext | OnCallBackQueryData): Promise<string> {
const prompt = `Improve this picture description using max 100 words and don't add additional text to the image: ${promptText} `
const conversation = [{ role: 'user', content: prompt, timestamp: Date.now() }]
const response = await chatCompletion(conversation, model)
const response = await chatCompletion(conversation, model, ctx)
return response.completion?.content as string ?? ''
}

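The behavioral detail worth calling out in `prepareConversation` is when the system message is injected: only on the first turn of a fresh conversation, and never for O1 (presumably because OpenAI's o1 models did not accept the `system` role at the time of this PR). A standalone sketch with simplified types; `currentPrompt` stands in for `ctx.session.currentPrompt`:

```ts
interface ChatMessage { role: string, content: string, model?: string }

function prepareConversation (conversation: ChatMessage[], model: string, currentPrompt: string): ChatMessage[] {
  const messages = conversation
    .filter(m => m.model === model)
    .map(m => ({ content: m.content, role: m.role }))
  // An ongoing chat already got the system message on its first turn, and o1
  // skips it entirely, so only a fresh single-message conversation qualifies.
  if (messages.length !== 1 || model === 'o1') {
    return messages
  }
  return [{ role: 'system', content: currentPrompt }, ...messages]
}

const first = prepareConversation(
  [{ role: 'user', content: 'hi', model: 'gpt-4o' }], 'gpt-4o', 'Be terse.'
)
// -> [{ role: 'system', content: 'Be terse.' }, { role: 'user', content: 'hi' }]
console.log(first)
```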
5 changes: 3 additions & 2 deletions src/modules/llms/api/vertex.ts
@@ -22,10 +22,12 @@ const logger = pino({
export const vertexCompletion = async (
conversation: ChatConversation[],
model = config.llms.model,
ctx: OnMessageContext | OnCallBackQueryData,
parameters?: ModelParameters
): Promise<LlmCompletion> => {
const data = {
model,
system: ctx.session.currentPrompt,
stream: false,
messages: conversation.filter(c => c.model === model)
.map((msg) => {
@@ -71,10 +73,9 @@ export const vertexStreamCompletion = async (
parameters?: ModelParameters
): Promise<LlmCompletion> => {
parameters = parameters ?? {
system: config.openAi.chatGpt.chatCompletionContext,
system: ctx.session.currentPrompt,
max_tokens: +config.openAi.chatGpt.maxTokens
}

const data = {
model,
stream: true, // Set stream to true to receive the completion as a stream
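In the non-streaming `vertexCompletion`, the session prompt is not tucked into `parameters` but sent as a top-level `system` field on the request body; that field name appears to belong to this repo's own LLM API endpoint rather than Google's raw Vertex API. An illustrative payload, with placeholder model id and message:

```ts
// Shape of the request body after this change (illustrative values only).
const data = {
  model: 'gemini-15',                     // placeholder model id
  system: 'You are a helpful assistant.', // ctx.session.currentPrompt at runtime
  stream: false,
  messages: [{ role: 'user', content: 'What is the tallest mountain on Earth?' }]
}
console.log(JSON.stringify(data, null, 2))
```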
11 changes: 9 additions & 2 deletions src/modules/llms/claudeBot.ts
@@ -42,6 +42,9 @@ export class ClaudeBot extends LlmsBase {
msgId: number,
limitTokens: boolean,
parameters?: ModelParameters): Promise<LlmCompletion> {
if (parameters) {
parameters.system = ctx.session.currentPrompt
}
return await anthropicStreamCompletion(
conversation,
model,
@@ -55,13 +58,17 @@ export class ClaudeBot extends LlmsBase {
async chatCompletion (
conversation: ChatConversation[],
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData,
hasTools: boolean,
parameters?: ModelParameters
): Promise<LlmCompletion> {
if (parameters) {
parameters.system = ctx.session.currentPrompt
}
if (hasTools) {
return await toolsChatCompletion(conversation, model, parameters)
return await toolsChatCompletion(conversation, model, ctx, parameters)
}
return await anthropicCompletion(conversation, model, parameters)
return await anthropicCompletion(conversation, model, ctx, parameters)
}

public async onEvent (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
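This `parameters.system = ctx.session.currentPrompt` stamp is the counterpart to the change in `llmsData.ts` further down, where the static `system` entries are commented out of every provider's `defaultParameters`: the model manager now hands back defaults without a system prompt, and the bot class fills it in from the session. A condensed sketch (simplified types, illustrative values):

```ts
interface ModelParameters { system?: string, max_tokens: number }

// Provider defaults no longer pin a static system prompt (see llmsData.ts below).
const defaults: ModelParameters = { max_tokens: 800 }

function applySessionPrompt (parameters: ModelParameters | undefined, currentPrompt: string): ModelParameters | undefined {
  if (parameters) {
    parameters.system = currentPrompt // mirrors `parameters.system = ctx.session.currentPrompt`
  }
  return parameters
}

const resolved = applySessionPrompt(defaults, 'You are a helpful assistant.')
console.log(resolved?.system) // 'You are a helpful assistant.'
```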
7 changes: 4 additions & 3 deletions src/modules/llms/dalleBot.ts
@@ -133,18 +133,19 @@ export class DalleBot extends LlmsBase {
): Promise<LlmCompletion> {
return await streamChatCompletion(
conversation,
ctx,
model,
ctx,
msgId,
true // telegram messages have a character limit
)
}

async chatCompletion (
conversation: ChatConversation[],
model: ModelVersion
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData
): Promise<LlmCompletion> {
return await chatCompletion(conversation, model)
return await chatCompletion(conversation, model, ctx)
}

hasPrefix (prompt: string): string {
5 changes: 3 additions & 2 deletions src/modules/llms/llmsBase.ts
@@ -120,6 +120,7 @@ export abstract class LlmsBase implements PayableBot {
protected abstract chatCompletion (
conversation: ChatConversation[],
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData,
usesTools: boolean,
parameters?: ModelParameters
): Promise<LlmCompletion>
@@ -373,7 +374,7 @@
}
} else {
const parameters = this.modelManager.getModelParameters(model)
const response = await this.chatCompletion(conversation, model, usesTools, parameters)
const response = await this.chatCompletion(conversation, model, ctx, usesTools, parameters)
conversation.push({
role: 'assistant',
content: response.completion?.content ?? '',
@@ -406,7 +407,7 @@
).message_id
ctx.chatAction = 'typing'
const parameters = this.modelManager.getModelParameters(model)
const response = await this.chatCompletion(conversation, model, usesTools, parameters)
const response = await this.chatCompletion(conversation, model, ctx, usesTools, parameters)
if (response.completion) {
if (model === this.modelsEnum.O1) {
const msgs = splitTelegramMessage(response.completion.content as string)
5 changes: 3 additions & 2 deletions src/modules/llms/openaiBot.ts
@@ -73,8 +73,8 @@ export class OpenAIBot extends LlmsBase {
): Promise<LlmCompletion> {
return await streamChatCompletion(
conversation,
ctx,
model,
ctx,
msgId,
true, // telegram messages have a character limit
parameters
@@ -84,10 +84,11 @@
async chatCompletion (
conversation: ChatConversation[],
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData,
usesTools: boolean,
parameters?: ModelParameters
): Promise<LlmCompletion> {
return await chatCompletion(conversation, model, model !== this.modelsEnum.O1, parameters) // limitTokens doesn't apply for o1-preview
return await chatCompletion(conversation, model, ctx, model !== this.modelsEnum.O1, parameters) // limitTokens doesn't apply for o1-preview
}

hasPrefix (prompt: string): string {
8 changes: 4 additions & 4 deletions src/modules/llms/utils/llmsData.ts
@@ -223,25 +223,25 @@ export const llmData: LLMData = {
},
claude: {
defaultParameters: {
system: config.openAi.chatGpt.chatCompletionContext,
// system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens
}
},
xai: {
defaultParameters: {
system: config.openAi.chatGpt.chatCompletionContext,
// system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens
}
},
vertex: {
defaultParameters: {
system: config.openAi.chatGpt.chatCompletionContext,
// system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens
}
},
luma: {
defaultParameters: {
system: config.openAi.chatGpt.chatCompletionContext,
// system: config.openAi.chatGpt.chatCompletionContext,
max_tokens: +config.openAi.chatGpt.maxTokens
}
}
9 changes: 8 additions & 1 deletion src/modules/llms/vertexBot.ts
@@ -46,6 +46,9 @@ export class VertexBot extends LlmsBase {
msgId: number,
limitTokens: boolean,
parameters?: ModelParameters): Promise<LlmCompletion> {
if (parameters) {
parameters.system = ctx.session.currentPrompt
}
return await vertexStreamCompletion(conversation,
model,
ctx,
@@ -58,10 +61,14 @@
async chatCompletion (
conversation: ChatConversation[],
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData,
usesTools: boolean,
parameters?: ModelParameters
): Promise<LlmCompletion> {
return await vertexCompletion(conversation, model, parameters)
if (parameters) {
parameters.system = ctx.session.currentPrompt
}
return await vertexCompletion(conversation, model, ctx, parameters)
}

public async onEvent (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
6 changes: 5 additions & 1 deletion src/modules/llms/xaiBot.ts
@@ -52,10 +52,14 @@ export class XaiBot extends LlmsBase {
async chatCompletion (
conversation: ChatConversation[],
model: ModelVersion,
ctx: OnMessageContext | OnCallBackQueryData,
hasTools: boolean,
parameters?: ModelParameters
): Promise<LlmCompletion> {
return await xaiCompletion(conversation, model, parameters)
if (parameters) {
parameters.system = ctx.session.currentPrompt
}
return await xaiCompletion(conversation, model, ctx, parameters)
}

public async onEvent (ctx: OnMessageContext | OnCallBackQueryData): Promise<void> {
1 change: 1 addition & 0 deletions src/modules/types.ts
@@ -187,6 +187,7 @@ export interface BotSessionData {
subagents: SubagentSessionData
dalle: ImageGenSessionData
currentModel: ModelVersion
currentPrompt: string
lastBroadcast: string
voiceMemo: VoiceMemoSessionData
}
2 changes: 1 addition & 1 deletion src/modules/voice-to-voice-gpt/index.ts
@@ -64,7 +64,7 @@ export class VoiceToVoiceGPTBot implements PayableBot {
fs.rmSync(filename)

const conversation = [{ role: 'user', content: resultText, timestamp: Date.now() }]
const response = await chatCompletion(conversation, LlmModelsEnum.GPT_35_TURBO)
const response = await chatCompletion(conversation, LlmModelsEnum.GPT_35_TURBO, ctx)

const voiceResult = await generateVoiceFromText(response.completion?.content as string)
// const voiceResult = await gcTextToSpeedClient.ssmlTextToSpeech({ text: response.completion, ssmlGender: 'MALE', languageCode: 'en-US' })