
Package Chat as feature resulting in better abstraction #12

Merged · 11 commits · Dec 30, 2024
README.md (+33 −11)
@@ -4,7 +4,7 @@ A developer-friendly Lua interface for working with multiple generative AI providers

## Providers

-> ⚠️ This is a work in progress so any help is appreciated!
+> ⚠️ This is a work in progress so any help is highly appreciated!

- [OpenAI](https://platform.openai.com/docs/overview)

@@ -23,19 +23,41 @@ A developer-friendly Lua interface for working with multiple generative AI providers

```lua
local AI = require("src.ai")
-local Chat = require("src.chat")

-local api_key = "<YOUR_API_KEY>"
-local endpoint = "https://api.openai.com/v1/chat/completions"
-local model = "gpt-4o-mini"
-local system_prompt = "You are Torben, the king of a nation."
-local settings = { stream = true }
+local client = AI.new("<YOUR_API_KEY>", "https://api.openai.com/v1/chat/completions")
+```

+### Minimal

+```lua
+local chat = client:chat("gpt-4o-mini")
+print(chat:say("Hello, world!"))
+```

-local ai = AI.new(api_key, endpoint)
-local chat = Chat.new(ai, model, system_prompt, settings)
+### Streaming

-local reply = chat:say("Give three short words of advice to the hero.")
-if not chat.settings.stream then print(reply) end
+```lua
+local chat = client:chat("gpt-4o-mini", { settings = { stream = true } })
+chat:say("Hello, world!")
```

+### JSON

+```lua
+local npc_schema = {
+    name = { type = "string" },
+    class = { type = "string" },
+    level = { type = "integer" },
+}

+local json_object = {
+    title = "NPC",
+    description = "A non-player character's attributes.",
+    schema = npc_schema,
+}

+local chat = client:chat("gpt-4o-mini", { settings = { json = json_object } })
+print(chat:say("Create a powerful wizard called Torben."))
+```

See `main.lua` for a more detailed example.
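Worth a note on the JSON example above: the structured reply comes back as a string, so callers will usually decode it before use. A minimal sketch, assuming the `json` setting makes `chat:say` return a JSON-encoded string, and reusing the `cjson` handle that `src/ai.lua` already pulls from `src.config`:

```lua
local config = require("src.config")
local cjson = config.cjson -- same JSON binding the library itself uses

-- `chat` is the JSON-mode chat from the README example above.
local reply = chat:say("Create a powerful wizard called Torben.")
local npc = cjson.decode(reply) -- a Lua table shaped like npc_schema
print(npc.name, npc.class, npc.level)
```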
main.lua (+9 −14)
@@ -1,6 +1,5 @@
local config = require("src.config")
local AI = require("src.ai")
-local Chat = require("src.chat")

local api_keys = config.api_keys

@@ -20,25 +19,21 @@ local structured_response_obj = {
-- local api_key = api_keys.anthropic_api_key
-- local endpoint = "https://api.anthropic.com/v1/messages"
-- local model = "claude-3-5-sonnet-20241022"
--- local settings = {
---     stream = false,
---     json = structured_response_obj,
--- }

local api_key = api_keys.openai_api_key
local endpoint = "https://api.openai.com/v1/chat/completions"
local model = "gpt-4o-mini"
-local settings = {
-    stream = false,
-    json = structured_response_obj,
-}

-local system_prompt = "Respond extremely briefly."

-local ai = AI.new(api_key, endpoint)
-local chat = Chat.new(ai, model, system_prompt, settings)

local function main()
+    local client = AI.new(api_key, endpoint)
+    local chat = client:chat(model, {
+        system_prompt = "Respond extremely briefly.",
+        settings = {
+            stream = false,
+            json = structured_response_obj,
+        },
+    })

    while true do
        local user_prompt = "You are King Torben giving advice."
        print(user_prompt)
src/ai.lua (+13 −1)
@@ -1,6 +1,8 @@
local config = require("src.config")
local utils = require("src.utils")
-local providers = require("src.providers._load")
+local providers = require("src.providers")
+local features = require("src.features")

local cjson = config.cjson
---Client for interacting with specified API endpoint
@@ -86,4 +88,14 @@ function AI:call(opts)
    return reply, input_tokens, output_tokens
end

+-- features:

+---Create chat instance with automatic tracking of messages and tokens
+---@param model string
+---@param opts table? Containing **settings** and/or **system_prompt**
+---@return Chat
+function AI:chat(model, opts)
+    return features.Chat.new(self, model, opts)
+end

return AI
src/features.lua (+6 −0)
@@ -0,0 +1,6 @@
+---@module "src.features"
+local features = {}

+features.Chat = require("src.features.chat")

+return features
src/chat.lua → src/features/chat.lua (+13 −12)
@@ -1,7 +1,7 @@
local utils = require("src.utils")

---@class Chat Accumulating chat history and usage
----@field _ai table
+---@field ai table
---@field model string
---@field settings table?
---@field usage table
@@ -12,20 +12,19 @@ Chat.__index = Chat

---@param ai table
---@param model string
----@param system_prompt string?
----@param settings table?
-function Chat.new(ai, model, system_prompt, settings)
+---@param opts table? Containing **settings** and/or **system_prompt**
+function Chat.new(ai, model, opts)
    local self = setmetatable({}, Chat)

-    self._ai = ai
+    self.ai = ai
    self.model = model
-    self.settings = settings or {}
+    self.settings = opts and opts.settings or {}
    self.usage = { input = 0, output = 0 }
    self.history = {}
-    self.system_prompt = system_prompt
+    self.system_prompt = opts and opts.system_prompt

    -- insert system prompt into chat history at the start if provided
-    local system_message = self._ai.provider.construct_system_message(self.system_prompt)
+    local system_message = self.ai.provider.construct_system_message(self.system_prompt)
    if system_message then -- some providers use system message as top-level arg
        table.insert(self.history, system_message)
    end
@@ -37,16 +36,18 @@ end
---@param user_prompt string
---@return string reply Full response text whether streamed or not
function Chat:say(user_prompt)
-    table.insert(self.history, self._ai.provider.construct_user_message(user_prompt))
-    local reply, input_tokens, output_tokens = self._ai:call(self)
-    table.insert(self.history, self._ai.provider.construct_assistant_message(reply))
+    table.insert(self.history, self.ai.provider.construct_user_message(user_prompt))
+    local reply, input_tokens, output_tokens = self.ai:call(self)
+    table.insert(self.history, self.ai.provider.construct_assistant_message(reply))
    self.usage.input = self.usage.input + input_tokens
    self.usage.output = self.usage.output + output_tokens
    return reply
end

---Calculate model pricing from input and output tokens in USD
---@return number
function Chat:get_cost()
-    return utils.calc_token_cost(self.model, self.usage, self._ai.provider.pricing)
+    return utils.calc_token_cost(self.model, self.usage, self.ai.provider.pricing)
end

return Chat
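Taken together: `Chat` now owns the history and token bookkeeping while `AI` stays a thin transport, which also makes the feature testable in isolation. A minimal sketch of the accumulation behaviour with a stubbed client and provider (the stub, its message shape, and its token counts are invented for illustration; run from the repo root so `src.utils` resolves):

```lua
local Chat = require("src.features.chat")

-- Stub provider implementing just the surface Chat touches.
local provider = {
    construct_system_message = function(text)
        return text and { role = "system", content = text } or nil
    end,
    construct_user_message = function(text)
        return { role = "user", content = text }
    end,
    construct_assistant_message = function(text)
        return { role = "assistant", content = text }
    end,
}

local fake_ai = { provider = provider }
function fake_ai:call(chat) -- same signature Chat:say relies on
    return "Greetings!", 12, 3 -- reply, input tokens, output tokens
end

local chat = Chat.new(fake_ai, "gpt-4o-mini", { system_prompt = "Respond extremely briefly." })
chat:say("Hello")
print(#chat.history) --> 3 (system + user + assistant)
print(chat.usage.input, chat.usage.output) --> 12    3
```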
src/providers.lua (+7 −0)
@@ -0,0 +1,7 @@
+---@module "src.providers"
+local providers = {}

+providers.openai = require("src.providers.openai")
+providers.anthropic = require("src.providers.anthropic")

+return providers
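For anyone adding a provider module: the surface the rest of this PR consumes is the three message constructors (see `Chat.new` and `Chat:say`) plus a `pricing` table read by `Chat:get_cost`. A sketch of that contract, inferred from those call sites; the message shape and pricing layout are assumptions, and a real module will also need whatever `AI:call` expects for requests and streaming:

```lua
---@module "src.providers.example" -- hypothetical provider, for illustration only
local example = {}

-- Providers that pass the system prompt as a top-level argument can return nil
-- here to keep it out of the chat history (see the branch in Chat.new).
function example.construct_system_message(text)
    return text and { role = "system", content = text } or nil
end

function example.construct_user_message(text)
    return { role = "user", content = text }
end

function example.construct_assistant_message(text)
    return { role = "assistant", content = text }
end

-- Read by utils.calc_token_cost via Chat:get_cost; per-token USD rates assumed.
example.pricing = {
    ["example-model"] = { input = 0.15e-6, output = 0.6e-6 },
}

return example
```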
src/providers/_load.lua (+0 −9)

This file was deleted.