Skip to content

Commit

Permalink
Merge pull request #844 from trheyi/main
Browse files Browse the repository at this point in the history
Refactor assistant types and tools support
  • Loading branch information
trheyi authored Feb 1, 2025
2 parents c5a9d8e + 455e39f commit 465a3fc
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 92 deletions.
18 changes: 12 additions & 6 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,13 @@ func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []chatMessag

// contents
fmt.Println("---contents ---")
utils.Dump(contents)
if contents.Data != nil {
fmt.Println("---contents.Data ---")
for _, content := range contents.Data {
fmt.Println(content.Map())
}
fmt.Println("---contents.Data end ---")
}
fmt.Println("---contents end ---")

// Add mentions
Expand All @@ -465,9 +471,9 @@ func (ast *Assistant) withOptions(options map[string]interface{}) map[string]int
}
}

// Add functions
if ast.Functions != nil && len(ast.Functions) > 0 {
options["tools"] = ast.Functions
// Add tools
if ast.Tools != nil && len(ast.Tools) > 0 {
options["tools"] = ast.Tools
if options["tool_choice"] == nil {
options["tool_choice"] = "auto"
}
Expand Down Expand Up @@ -554,8 +560,8 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []chatMessag

for index, message := range messages {

// Ignore the function call message
if message.Type == "function" {
// Ignore the tool call message
if message.Type == "tool_calls" {
continue
}

Expand Down
26 changes: 25 additions & 1 deletion neo/assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (ast *Assistant) Map() map[string]interface{} {
"description": ast.Description,
"options": ast.Options,
"prompts": ast.Prompts,
"functions": ast.Functions,
"tools": ast.Tools,
"tags": ast.Tags,
"mentionable": ast.Mentionable,
"automated": ast.Automated,
Expand Down Expand Up @@ -199,6 +199,12 @@ func (ast *Assistant) Clone() *Assistant {
copy(clone.Prompts, ast.Prompts)
}

// Deep copy tools
if ast.Tools != nil {
clone.Tools = make([]Tool, len(ast.Tools))
copy(clone.Tools, ast.Tools)
}

// Deep copy flows
if ast.Flows != nil {
clone.Flows = make([]map[string]interface{}, len(ast.Flows))
Expand Down Expand Up @@ -232,6 +238,24 @@ func (ast *Assistant) Update(data map[string]interface{}) error {
if v, ok := data["connector"].(string); ok {
ast.Connector = v
}

if v, has := data["tools"]; has {
switch tools := v.(type) {
case []Tool:
ast.Tools = tools
default:
raw, err := jsoniter.Marshal(tools)
if err != nil {
return err
}
ast.Tools = []Tool{}
err = jsoniter.Unmarshal(raw, &ast.Tools)
if err != nil {
return err
}
}
}

if v, ok := data["type"].(string); ok {
ast.Type = v
}
Expand Down
91 changes: 48 additions & 43 deletions neo/assistant/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,15 @@ func LoadPath(path string) (*Assistant, error) {
data["updated_at"] = max(updatedAt, ts)
}

// load functions
functionsfile := filepath.Join(path, "functions.json")
if has, _ := app.Exists(functionsfile); has {
functions, ts, err := loadFunctions(functionsfile)
// load tools
toolsfile := filepath.Join(path, "tools.yao")
if has, _ := app.Exists(toolsfile); has {
tools, ts, err := loadTools(toolsfile)
if err != nil {
return nil, err
}
data["functions"] = functions
data["tools"] = tools
updatedAt = max(updatedAt, ts)
data["updated_at"] = updatedAt
}

// load flow
Expand Down Expand Up @@ -440,22 +439,24 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
assistant.Prompts = prompts
}

// functions
if funcs, has := data["functions"]; has {
switch vv := funcs.(type) {
case []Function:
assistant.Functions = vv
// tools
if tools, has := data["tools"]; has {
switch vv := tools.(type) {
case []Tool:
assistant.Tools = vv

default:
raw, err := jsoniter.Marshal(vv)
raw, err := jsoniter.Marshal(tools)
if err != nil {
return nil, err
return nil, fmt.Errorf("tools format error %s", err.Error())
}
var functions []Function
err = jsoniter.Unmarshal(raw, &functions)

var tools []Tool
err = jsoniter.Unmarshal(raw, &tools)
if err != nil {
return nil, err
return nil, fmt.Errorf("tools format error %s", err.Error())
}
assistant.Functions = functions
assistant.Tools = tools
}
}

Expand Down Expand Up @@ -501,32 +502,6 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
return assistant, nil
}

func loadFunctions(file string) ([]Function, int64, error) {

app, err := fs.Get("app")
if err != nil {
return nil, 0, err
}

ts, err := app.ModTime(file)
if err != nil {
return nil, 0, err
}

raw, err := app.ReadFile(file)
if err != nil {
return nil, 0, err
}

var functions []Function
err = jsoniter.Unmarshal(raw, &functions)
if err != nil {
return nil, 0, err
}

return functions, ts.UnixNano(), nil
}

func loadPrompts(file string, root string) (string, int64, error) {

app, err := fs.Get("app")
Expand Down Expand Up @@ -629,3 +604,33 @@ func (ast *Assistant) initialize() error {

return nil
}

func loadTools(file string) ([]Tool, int64, error) {

app, err := fs.Get("app")
if err != nil {
return nil, 0, err
}

content, err := app.ReadFile(file)
if err != nil {
return nil, 0, err
}

ts, err := app.ModTime(file)
if err != nil {
return nil, 0, err
}

if len(content) == 0 {
return []Tool{}, ts.UnixNano(), nil
}

var tools []Tool
err = jsoniter.Unmarshal(content, &tools)
if err != nil {
return nil, 0, err
}

return tools, ts.UnixNano(), nil
}
6 changes: 3 additions & 3 deletions neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ type Prompt struct {
Name string `json:"name,omitempty"`
}

// Function a function
type Function struct {
// Tool represents a tool
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Expand Down Expand Up @@ -129,7 +129,7 @@ type Assistant struct {
Automated bool `json:"automated,omitempty"` // Whether this assistant is automated
Options map[string]interface{} `json:"options,omitempty"` // AI Options
Prompts []Prompt `json:"prompts,omitempty"` // AI Prompts
Functions []Function `json:"functions,omitempty"` // Assistant Functions
Tools []Tool `json:"tools,omitempty"` // Assistant Tools
Flows []map[string]interface{} `json:"flows,omitempty"` // Assistant Flows
Placeholder *Placeholder `json:"placeholder,omitempty"` // Assistant Placeholder
Script *v8.Script `json:"-" yaml:"-"` // Assistant Script
Expand Down
19 changes: 11 additions & 8 deletions neo/message/contents.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package message

import (
"fmt"

jsoniter "github.com/json-iterator/go"
)

Expand All @@ -21,12 +23,12 @@ type Contents struct {

// Data the data of the content
type Data struct {
Type string `json:"type"` // text, function, error, ...
ID string `json:"id"` // the id of the content
Function string `json:"function"` // the function name
Bytes []byte `json:"bytes"` // the content bytes
Arguments []byte `json:"arguments"` // the function arguments
Props map[string]interface{} `json:"props"` // the props
Type string `json:"type"` // text, function, error, ...
ID string `json:"id"` // the id of the content
Function string `json:"function"` // the function name
Bytes []byte `json:"bytes"` // the content bytes
Arguments []byte `json:"arguments,omitempty"` // the function arguments
Props map[string]interface{} `json:"props"` // the props
}

// NewContents create a new contents
Expand Down Expand Up @@ -159,7 +161,8 @@ func (data *Data) Map() (map[string]interface{}, error) {
v["props"] = data.Props
}

if data.Arguments != nil {
if data.Arguments != nil && len(data.Arguments) > 0 {
fmt.Println("data.Arguments", string(data.Arguments))
var vv interface{} = nil
err := jsoniter.Unmarshal(data.Arguments, &vv)
if err != nil {
Expand Down Expand Up @@ -192,7 +195,7 @@ func (data *Data) MarshalJSON() ([]byte, error) {
v["props"] = data.Props
}

if data.Arguments != nil {
if data.Arguments != nil && len(data.Arguments) > 0 {
var vv interface{} = nil
err := jsoniter.Unmarshal(data.Arguments, &vv)
if err != nil {
Expand Down
35 changes: 32 additions & 3 deletions neo/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ func NewOpenAI(data []byte) *Message {
case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"tool_calls"`):
var toolCalls openai.ToolCalls
if err := jsoniter.Unmarshal(data, &toolCalls); err != nil {
msg.Text = err.Error() + "\n" + string(data)
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}

Expand All @@ -205,7 +209,11 @@ func NewOpenAI(data []byte) *Message {
case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"content":`):
var message openai.Message
if err := jsoniter.Unmarshal(data, &message); err != nil {
msg.Text = err.Error() + "\n" + string(data)
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}

Expand All @@ -214,14 +222,35 @@ func NewOpenAI(data []byte) *Message {
msg.Text = message.Choices[0].Delta.Content
}

case strings.Index(text, `{"code":`) == 0:
var errorMessage openai.Error
if err := jsoniter.UnmarshalFromString(text, &errorMessage); err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}
msg.Type = "error"
msg.Text = errorMessage.Message
msg.IsDone = true
break

case strings.Contains(text, `{"error":{`):
var errorMessage openai.ErrorMessage
if err := jsoniter.Unmarshal(data, &errorMessage); err != nil {
msg.Text = err.Error() + "\n" + string(data)
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}
msg.Type = "error"
msg.Text = errorMessage.Error.Message
msg.IsDone = true
break

case strings.Contains(text, `[DONE]`):
msg.IsDone = true
Expand Down
10 changes: 5 additions & 5 deletions neo/store/xun.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (conv *Xun) initAssistantTable() error {
table.JSON("prompts").Null() // assistant prompts
table.JSON("flows").Null() // assistant flows
table.JSON("files").Null() // assistant files
table.JSON("functions").Null() // assistant functions
table.JSON("tools").Null() // assistant tools
table.JSON("tags").Null() // assistant tags
table.Boolean("readonly").SetDefault(false).Index() // assistant readonly
table.JSON("permissions").Null() // assistant permissions
Expand All @@ -259,7 +259,7 @@ func (conv *Xun) initAssistantTable() error {
return err
}

fields := []string{"id", "assistant_id", "type", "name", "avatar", "connector", "description", "path", "sort", "built_in", "placeholder", "options", "prompts", "flows", "files", "functions", "tags", "mentionable", "created_at", "updated_at"}
fields := []string{"id", "assistant_id", "type", "name", "avatar", "connector", "description", "path", "sort", "built_in", "placeholder", "options", "prompts", "flows", "files", "tools", "tags", "mentionable", "created_at", "updated_at"}
for _, field := range fields {
if !tab.HasColumn(field) {
return fmt.Errorf("%s is required", field)
Expand Down Expand Up @@ -767,7 +767,7 @@ func (conv *Xun) SaveAssistant(assistant map[string]interface{}) (interface{}, e
}

// Process JSON fields
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "functions", "permissions", "placeholder"}
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "tools", "permissions", "placeholder"}
for _, field := range jsonFields {
if val, ok := assistantCopy[field]; ok && val != nil {
// If it's a string, try to parse it first
Expand Down Expand Up @@ -954,7 +954,7 @@ func (conv *Xun) GetAssistants(filter AssistantFilter) (*AssistantResponse, erro

// Convert rows to map slice and parse JSON fields
data := make([]map[string]interface{}, len(rows))
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "functions", "permissions", "placeholder"}
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "tools", "permissions", "placeholder"}
for i, row := range rows {
data[i] = row
// Only parse JSON fields if they are selected or no select filter is provided
Expand Down Expand Up @@ -1008,7 +1008,7 @@ func (conv *Xun) GetAssistant(assistantID string) (map[string]interface{}, error
}

// Parse JSON fields
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "functions", "permissions", "placeholder"}
jsonFields := []string{"tags", "options", "prompts", "flows", "files", "tools", "permissions", "placeholder"}
conv.parseJSONFields(data, jsonFields)

return data, nil
Expand Down
Loading

0 comments on commit 465a3fc

Please sign in to comment.