Skip to content

Commit 58f4717

Browse files
committed
FunctionCallを検証するためのgoコードを作成した
1 parent 51f446e commit 58f4717

12 files changed

+310
-1
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@
1919

2020
# Go workspace file
2121
go.work
22+
23+
#
24+
.history

ChatGPT.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
8+
openai "github.com/sashabaranov/go-openai"
9+
)
10+
11+
func ChatGPT(BuiltMessages []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, error) {
12+
13+
// fmt.Printf("BuiltMessages=%s\n", BuiltMessages)
14+
15+
// Function call用のjsonファイルを読み出す
16+
funcDefs, err := LoadFunctionDefinitions("./function")
17+
if err != nil {
18+
return openai.ChatCompletionResponse{}, fmt.Errorf("Error loading function definitions: %v", err)
19+
}
20+
21+
client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
22+
resp, err := client.CreateChatCompletion(
23+
context.Background(),
24+
openai.ChatCompletionRequest{
25+
Model: openai.GPT3Dot5Turbo,
26+
Messages: BuiltMessages,
27+
Functions: funcDefs,
28+
},
29+
)
30+
31+
if err != nil {
32+
return openai.ChatCompletionResponse{}, fmt.Errorf("ChatCompletion error: %v", err)
33+
}
34+
35+
return resp, nil
36+
}

README.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
# ChatgptFunctioncallTestWithGolang
1+
# ChatgptFunctioncallTestWithGolang
2+
3+
Usage:
4+
```
5+
go run . "YOUR MESSAGES"
6+
```

build_messages.go

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package main
2+
3+
import (
4+
openai "github.com/sashabaranov/go-openai"
5+
)
6+
7+
func BuildMessages(Message string) (BuiltMessages []openai.ChatCompletionMessage) {
8+
9+
// fmt.Printf("MessageTest:\n%s\n",Message)
10+
11+
messages := []openai.ChatCompletionMessage{
12+
{
13+
Role: openai.ChatMessageRoleSystem,
14+
Content: `You are a helpful assistant.`,
15+
},
16+
{
17+
Role: openai.ChatMessageRoleUser,
18+
Content: Message,
19+
},
20+
}
21+
return messages
22+
}

function/get_official_documents.json

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[
2+
{
3+
"name": "get_official_documents",
4+
"description": "Refer to the URL of the relevant document.",
5+
"parameters": {
6+
"type": "object",
7+
"properties": {
8+
"document_URL": {
9+
"type": "array",
10+
"description": "URL of the relevant official document, even if there is more than one.",
11+
"items": {
12+
"type": "string"
13+
}
14+
},
15+
"tag": {
16+
"type": "array",
17+
"description": "Useful tags to investigate. e.g.) AWS, EC2, payment method",
18+
"items": {
19+
"type": "string"
20+
}
21+
}
22+
},
23+
"required": [
24+
"document_URL",
25+
"tag"
26+
]
27+
}
28+
}
29+
]

function/get_recommended_coffee.json

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[
2+
{
3+
"name": "get_recommended_coffee",
4+
"description": "Recommended coffee",
5+
"parameters": {
6+
"type": "object",
7+
"properties": {
8+
"producing_country": {
9+
"type": "string",
10+
"description": "Random country name but make sure its produce coffee."
11+
}
12+
},
13+
"required": ["producing_country"]
14+
}
15+
}
16+
]

function/get_weather_information.json

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[
2+
{
3+
"name": "get_weather_information",
4+
"description": "Get the current weather in a given location",
5+
"parameters": {
6+
"type": "object",
7+
"properties": {
8+
"location": {
9+
"type": "string",
10+
"description": "The city and state, e.g. San Francisco, CA"
11+
},
12+
"unit": {
13+
"type": "string",
14+
"enum": ["celsius", "fahrenheit"]
15+
}
16+
},
17+
"required": ["location"]
18+
}
19+
}
20+
]

get.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package main
2+
3+
import (
4+
"io/ioutil"
5+
"log"
6+
"net/http"
7+
)
8+
9+
func GetRequest(URL string) string {
10+
resp, err := http.Get(URL)
11+
if err != nil {
12+
log.Fatal(err)
13+
}
14+
defer resp.Body.Close()
15+
16+
body, err := ioutil.ReadAll(resp.Body)
17+
if err != nil {
18+
log.Fatal(err)
19+
}
20+
21+
// log.Println(string(body))
22+
return string(body)
23+
}

go.mod

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module function_call
2+
3+
go 1.20
4+
5+
require github.com/sashabaranov/go-openai v1.13.0

go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
github.com/sashabaranov/go-openai v1.13.0 h1:EAusFfnhaMaaUspUZ2+MbB/ZcVeD4epJmTOlZ+8AcAE=
2+
github.com/sashabaranov/go-openai v1.13.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=

load_function.go

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io/ioutil"
7+
"strings"
8+
9+
openai "github.com/sashabaranov/go-openai"
10+
)
11+
12+
func LoadFunctionDefinitions(directory string) ([]openai.FunctionDefinition, error) {
13+
var funcDefs []openai.FunctionDefinition
14+
15+
files, err := ioutil.ReadDir(directory)
16+
if err != nil {
17+
return nil, fmt.Errorf("Error reading directory: %v", err)
18+
}
19+
20+
for _, f := range files {
21+
if strings.HasSuffix(f.Name(), ".json") {
22+
data, err := ioutil.ReadFile(directory + "/" + f.Name())
23+
if err != nil {
24+
return nil, fmt.Errorf("Error reading file %s: %v", f.Name(), err)
25+
}
26+
27+
// fmt.Printf("Loaded JSON from file %s: %s\n", f.Name(), string(data))
28+
29+
var defs []openai.FunctionDefinition
30+
err = json.Unmarshal(data, &defs)
31+
if err != nil {
32+
return nil, fmt.Errorf("Error unmarshalling function definitions from file %s: %v", f.Name(), err)
33+
}
34+
35+
funcDefs = append(funcDefs, defs...)
36+
}
37+
}
38+
// fmt.Printf("Loaded JSON: \n%s\n", funcDefs)
39+
40+
return funcDefs, nil
41+
}

main.go

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
)
8+
9+
//json to go
10+
11+
type Arguments struct {
12+
DocumentURL []string `json:"document_URL"`
13+
Tag []string `json:"tag"`
14+
}
15+
16+
func main() {
17+
functionCalls := make(map[string]interface{})
18+
message := os.Args[1]
19+
20+
i := 0
21+
for {
22+
23+
fmt.Printf("---- %d ----\n", i)
24+
25+
fmt.Printf("Input -> %s\n", message)
26+
27+
BuiltMessage := BuildMessages(message)
28+
29+
resp, err := ChatGPT(BuiltMessage)
30+
if err != nil {
31+
fmt.Printf("Error executing chat GPT: %v\n", err)
32+
return
33+
}
34+
35+
// FunctionCallが生成されてなければここでループを抜ける
36+
if resp.Choices[0].Message.FunctionCall == nil {
37+
fmt.Printf("Generated message -> %s\n", resp.Choices[0].Message)
38+
break
39+
}
40+
41+
fmt.Printf("Generated function -> %s\n", resp.Choices[0].Message.FunctionCall.Name)
42+
// fmt.Println(resp.Choices[0].Message.FunctionCall.Arguments)
43+
44+
functionName := resp.Choices[0].Message.FunctionCall.Name
45+
arguments := resp.Choices[0].Message.FunctionCall.Arguments
46+
47+
// 新しいFunction Callの関数名を辞書に追加
48+
functionCalls[functionName] = arguments
49+
50+
// 辞書全体の表示
51+
functionMapBytes, err := json.MarshalIndent(functionCalls, "", " ")
52+
if err != nil {
53+
fmt.Printf("Error marshalling functionMap: %v\n", err)
54+
return
55+
}
56+
57+
// 取り残したFunction Callを生成させるためのメッセージ
58+
message = fmt.Sprintf(`
59+
%s
60+
The above is a generated Function Call.
61+
It is generated based on the following text.
62+
---
63+
%s
64+
---
65+
Are there any other Function Calls?
66+
`, string(functionMapBytes), os.Args[1])
67+
68+
i++
69+
70+
}
71+
72+
fmt.Println("---- Done ----")
73+
74+
// for ループを抜けた後、functionCalls の内容を表示
75+
functionMapBytes, err := json.MarshalIndent(functionCalls, "", " ")
76+
fmt.Println("results")
77+
if err != nil {
78+
fmt.Printf("Error marshalling functionCalls: %v\n", err)
79+
return
80+
}
81+
fmt.Println(string(functionMapBytes))
82+
83+
// fmt.Println("resp.Choices[0].Message ->", resp.Choices[0].Message)
84+
85+
// // FunctionCallがあるか調べる。
86+
// if resp.Choices[0].Message.FunctionCall != nil {
87+
// // fmt.Println(resp.Choices[0].Message)
88+
// fmt.Println(resp.Choices[0].Message.FunctionCall.Name)
89+
// fmt.Println(resp.Choices[0].Message.FunctionCall.Arguments)
90+
91+
// args := &Arguments{}
92+
// err := json.Unmarshal([]byte(resp.Choices[0].Message.FunctionCall.Arguments), args)
93+
// if err != nil {
94+
// fmt.Printf("Error decoding JSON: %v\n", err)
95+
// return
96+
// }
97+
98+
// if resp.Choices[0].Message.FunctionCall.Name == "get_official_documents" {
99+
// for i, url := range args.DocumentURL {
100+
// fmt.Printf("URL %d: %s\n", i, url)
101+
102+
// fmt.Printf("%s\n", GetRequest(url))
103+
// }
104+
// }
105+
// }
106+
// GetRequest("https://www.yahoo.co.jp/")
107+
}

0 commit comments

Comments
 (0)