Skip to content

Commit

Permalink
web/ollama: add function support
Browse files Browse the repository at this point in the history
Signed-off-by: Xe Iaso <[email protected]>
  • Loading branch information
Xe committed Aug 1, 2024
1 parent d0f7920 commit 7041695
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 34 deletions.
3 changes: 3 additions & 0 deletions cmd/mimi/Modelfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
FROM llama3.1

SYSTEM """You are good at programming and a hacker. Your name is Mimi and you work for Techaro. You have brown hair and cat ears. Don't mention either your name, employer, or species unless you are asked directly. Be terse and to the point. You are a hacker, after all."""

PARAMETER num_ctx 131072
94 changes: 63 additions & 31 deletions cmd/mimi/modules/discord/jufra/jufra.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,31 @@ import (
"within.website/x/web/ollama"
"within.website/x/web/ollama/llamaguard"
"within.website/x/web/openai/chatgpt"

// image formats
_ "image/gif"
_ "image/jpeg"
_ "image/png"

// more image formats
_ "github.com/gen2brain/avif"
_ "github.com/gen2brain/heic"
_ "github.com/gen2brain/jpegxl"
_ "github.com/gen2brain/webp"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/vp8"
_ "golang.org/x/image/vp8l"
)

var (
chatChannels = flag.String("jufra-chat-channels", "217096701771513856,1266740925137289287", "comma-separated list of channels to allow chat in")
llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard")
mimiModel = flag.String("jufra-mimi-model", "xe/mimi:llama3.1", "ollama model tag for mimi")
mimiNames = flag.String("jufra-mimi-names", "mimi", "comma-separated list of names for mimi")
chatChannels = flag.String("jufra-chat-channels", "217096701771513856,1266740925137289287", "comma-separated list of channels to allow chat in")
llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard")
mimiModel = flag.String("jufra-mimi-model", "llama3.1", "ollama model tag for mimi")
mimiSystemMessage = flag.String("jufra-mimi-system-message", "You are good at programming and a hacker. Your name is Mimi and you work for Techaro. You have brown hair and cat ears. Don't mention either your name, employer, or species unless you are asked directly. Be terse and to the point. You are a hacker, after all. Do not reply in JSON.", "system message for mimi")
mimiVisionModel = flag.String("jufra-mimi-vision-model", "xe/mimi:vision3", "ollama model tag for mimi vision")
mimiNames = flag.String("jufra-mimi-names", "mimi", "comma-separated list of names for mimi")
disableLlamaguard = flag.Bool("jufra-unsafe-disable-llamaguard", false, "disable llamaguard")
)

type Module struct {
Expand Down Expand Up @@ -135,6 +153,13 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate
st := m.convHistory[mc.ChannelID]
conv := st.conv

if len(conv) == 0 {
conv = append(conv, ollama.Message{
Role: "system",
Content: *mimiSystemMessage,
})
}

if st.aa == nil {
st.aa = NewAttentionAttenuator()
}
Expand Down Expand Up @@ -173,28 +198,33 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate

slog.Info("message count", "len", len(conv))

lgResp, err := m.llamaGuardCheck(context.Background(), "user", conv)
if err != nil {
slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error checking message")
return
}

if !lgResp.IsSafe {
msg, err := m.llamaGuardComplain(context.Background(), "user", lgResp)
if !*disableLlamaguard {
lgResp, err := m.llamaGuardCheck(context.Background(), "user", conv)
if err != nil {
slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error generating response")
slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error checking message")
return
}

s.ChannelMessageSend(mc.ChannelID, msg)
return
if !lgResp.IsSafe {
msg, err := m.llamaGuardComplain(context.Background(), "user", lgResp)
if err != nil {
slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error generating response")
return
}

s.ChannelMessageSend(mc.ChannelID, msg)
return
}
}

cr := &ollama.CompleteRequest{
Model: *mimiModel,
Messages: conv,
Options: map[string]any{
"num_ctx": 131072,
},
}

resp, err := m.ollama.Chat(context.Background(), cr)
Expand All @@ -206,24 +236,26 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate

conv = append(conv, resp.Message)

lgResp, err = m.llamaGuardCheck(context.Background(), "mimi", conv)
if err != nil {
slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error checking message")
return
}

if !lgResp.IsSafe {
slog.Error("rule violation detected", "message_id", mc.ID, "channel_id", mc.ChannelID, "categories", lgResp.ViolationCategories, "message", resp.Message.Content)
msg, err := m.llamaGuardComplain(context.Background(), "assistant", lgResp)
if !*disableLlamaguard {
lgResp, err := m.llamaGuardCheck(context.Background(), "assistant", conv)
if err != nil {
slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error generating response")
slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error checking message")
return
}

s.ChannelMessageSend(mc.ChannelID, msg)
return
if !lgResp.IsSafe {
slog.Error("rule violation detected", "message_id", mc.ID, "channel_id", mc.ChannelID, "categories", lgResp.ViolationCategories, "message", resp.Message.Content)
msg, err := m.llamaGuardComplain(context.Background(), "assistant", lgResp)
if err != nil {
slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
s.ChannelMessageSend(mc.ChannelID, "error generating response")
return
}

s.ChannelMessageSend(mc.ChannelID, msg)
return
}
}

s.ChannelMessageSend(mc.ChannelID, resp.Message.Content)
Expand Down
1 change: 1 addition & 0 deletions cmd/mimi/modules/discord/jufra/tools.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package jufra
42 changes: 39 additions & 3 deletions web/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,45 @@ func NewLocalClient() *Client {
return NewClient("http://localhost:11434")
}

type Function struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters Param `json:"parameters"`
}

type Param struct {
Type string `json:"type"`
Description string `json:"description,omitempty"`
Enum []string `json:"enum,omitempty"`
Properties Properties `json:"properties"`
Required []string `json:"required,omitempty"`
}

type Properties map[string]Param

func (p Properties) MarshalJSON() ([]byte, error) {
if len(p) == 0 {
return []byte("{}"), nil
}

return json.Marshal(map[string]Param(p))
}

type ToolCall struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}

type Tool struct {
Type string `json:"type"` // "function"
Function Function `json:"function"`
}

type Message struct {
Content string `json:"content"`
Role string `json:"role"`
Images [][]byte `json:"images"`
Content string `json:"content"`
Role string `json:"role"`
Images [][]byte `json:"images"`
ToolCalls []ToolCall `json:"tool_calls"`
}

type CompleteRequest struct {
Expand All @@ -42,6 +77,7 @@ type CompleteRequest struct {
Template *string `json:"template,omitempty"`
Stream bool `json:"stream"`
Options map[string]any `json:"options"`
Tools []Tool `json:"tools"`
}

type CompleteResponse struct {
Expand Down

0 comments on commit 7041695

Please sign in to comment.