Skip to content

Commit 6c9404a

Browse files
authored
Merge pull request #4 from emilrueh/dev
Abstract conversation logic and modularize providers
2 parents 1fddbfe + 2a0496f commit 6c9404a

13 files changed

+329
-260
lines changed

README.md

+53-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,55 @@
1-
A Lua wrapper for generative AI APIs.
1+
# Unified Lua Interface for Generative AI
22

3-
Currently implemented are OpenAI, Anthropic, Gemini, and Groq for Open-Source models.
3+
A developer-friendly Lua interface for working with multiple generative AI providers, abstracting away provider-specific payload structures and response parsing so you can easily switch between various models and providers without rewriting any code.
44

5-
This project is under construction and any help is appreciated.
5+
## Providers
6+
7+
- [OpenAI](https://platform.openai.com/docs/overview)
8+
9+
- [Anthropic](https://docs.anthropic.com/en/home)
10+
11+
## Features
12+
13+
- Easily switch between AI chat model providers.
14+
- Pass in prompts and get replies without the provider complexity.
15+
- Easily integrate new models and adjust settings.
16+
- Work directly with the `src.ai` client for more granular control.
17+
18+
## Usage
19+
20+
```lua
21+
local AI = require("src.ai")
22+
local Chat = require("src.chat")
23+
24+
local api_key = "<YOUR_API_KEY>"
25+
local endpoint = "https://api.openai.com/v1/chat/completions"
26+
local model = "gpt-4o-mini"
27+
local system_prompt = "You are the king of a nation."
28+
29+
local ai = AI.new(api_key, endpoint)
30+
local chat = Chat.new(ai, model, system_prompt)
31+
32+
local reply = chat:say("Give three short words of advice to the hero.")
33+
print(reply)
34+
```
35+
36+
See `main.lua` for a more detailed example.
37+
38+
### Dependencies
39+
40+
- [lua-https](https://github.com/love2d/lua-https)
41+
42+
- [lua-cjson](https://github.com/openresty/lua-cjson)
43+
44+
## Status
45+
46+
⚠️ This is a work in progress so any help is appreciated!
47+
48+
### Future
49+
50+
1. Streaming responses
51+
2. Error handling
52+
3. Token cost tracking
53+
4. Gemini and open-source model integration
54+
5. Image models
55+
6. Audio models

main.lua

+14-25
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
1-
--
2-
3-
if os.getenv("LOCAL_LUA_DEBUGGER_VSCODE") == "1" then
4-
require("lldebugger").start()
5-
end
6-
7-
--
8-
91
local config = require("src.config")
10-
local OpenAI = require("src.ai.openai")
2+
local AI = require("src.ai")
3+
local Chat = require("src.chat")
114

125
local api_keys = config.api_keys
136

14-
local client = OpenAI.new(api_keys.openai_api_key, "https://api.openai.com/v1/chat/completions")
7+
-- local api_key = api_keys.anthropic_api_key
8+
-- local endpoint = "https://api.anthropic.com/v1/messages"
9+
-- local model = "claude-3-5-haiku-20241022"
10+
11+
local api_key = api_keys.openai_api_key
12+
local endpoint = "https://api.openai.com/v1/chat/completions"
13+
local model = "gpt-4o-mini"
1514

16-
---@param model string
17-
local function conversation(model)
18-
local system_prompt = "Respond extremely briefly and concise."
15+
local system_prompt = "Respond extremely briefly."
1916

20-
local messages = {}
21-
table.insert(messages, { role = "system", content = system_prompt })
17+
local ai = AI.new(api_key, endpoint)
18+
local chat = Chat.new(ai, model, system_prompt)
2219

20+
local function main()
2321
while true do
2422
local user_prompt = io.read()
2523
print()
@@ -28,20 +26,11 @@ local function conversation(model)
2826
break
2927
end
3028

31-
table.insert(messages, { role = "user", content = user_prompt })
32-
33-
-- api call
34-
local reply, input_tokens, output_tokens = client:call(messages, model)
35-
36-
table.insert(messages, { role = "assistant", content = reply })
29+
local reply = chat:say(user_prompt)
3730

3831
print(reply)
3932
print()
4033
end
4134
end
4235

43-
local function main()
44-
conversation("gpt-4o-mini")
45-
end
46-
4736
main()

src/ai.lua

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
local config = require("src.config")
local utils = require("src.utils")

local providers = require("src.providers._load")

local cjson = config.cjson

---AI API client for interacting with a specified endpoint.
---The concrete provider module (payload construction, header construction,
---response parsing) is selected from the endpoint URL at construction time.
---@class AI
---@field _api_key string
---@field _endpoint string
---@field provider table|nil
---@field _determine_provider function
local AI = {}
AI.__index = AI

---Create a new AI client bound to one provider endpoint.
---@param api_key string secret key passed to the provider's header builder
---@param endpoint string full URL of the provider's completion API
---@return AI
function AI.new(api_key, endpoint)
	local self = setmetatable({}, AI)

	self._api_key = api_key
	self._endpoint = endpoint

	-- Determine AI provider module to load from the endpoint URL
	self.provider = self:_determine_provider(providers)
	assert(self.provider, "AI provider could not be determined from provided endpoint")

	return self
end

---Check endpoint for occurrence of an AI provider name.
---Uses a plain-text find so provider names are matched literally rather
---than interpreted as Lua patterns.
---@param providers table map of provider name -> provider module
---@return table|nil provider_module
function AI:_determine_provider(providers)
	for provider_name, provider_module in pairs(providers) do
		if self._endpoint:find(provider_name, 1, true) then
			return provider_module
		end
	end

	return nil -- default if no provider found
end

---API call to the specified model of the configured provider.
---@param opts table provider-agnostic call options (model, messages, ...)
---@return string reply assistant reply text
---@return number|nil input_tokens prompt token count, if the provider reports it
---@return number|nil output_tokens completion token count, if the provider reports it
function AI:call(opts)
	local headers = self.provider.construct_headers(self._api_key)
	local payload = cjson.encode(self.provider.construct_payload(opts))
	local response = cjson.decode(utils.send_request(self._endpoint, payload, "POST", headers))
	local reply, input_tokens, output_tokens = self.provider.extract_response_data(response)

	-- Propagate token usage instead of discarding it, so callers can
	-- implement cost tracking; existing callers reading only the first
	-- return value are unaffected (extra returns are ignored in Lua).
	return reply, input_tokens, output_tokens
end

return AI

src/ai/anthropic.lua

-74
This file was deleted.

src/ai/google.lua

-51
This file was deleted.

src/ai/groq.lua

-41
This file was deleted.

src/ai/openai.lua

-52
This file was deleted.

0 commit comments

Comments
 (0)