Skip to content

Commit a2ca01b

Browse files
authored
feat: implement new fine tuning job API (#479)
* feat: implement new fine tuning job API * fix: export ListFineTuningJobEventsParameter * fix: lint errors * fix: test errors * fix: code test coverage * fix: code test coverage * fix: use any * chore: use url.Values
1 parent a14bc10 commit a2ca01b

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed

client_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
223223
{"ListFineTuneEvents", func() (any, error) {
224224
return client.ListFineTuneEvents(ctx, "")
225225
}},
226+
{"CreateFineTuningJob", func() (any, error) {
227+
return client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
228+
}},
229+
{"CancelFineTuningJob", func() (any, error) {
230+
return client.CancelFineTuningJob(ctx, "")
231+
}},
232+
{"RetrieveFineTuningJob", func() (any, error) {
233+
return client.RetrieveFineTuningJob(ctx, "")
234+
}},
235+
{"ListFineTuningJobEvents", func() (any, error) {
236+
return client.ListFineTuningJobEvents(ctx, "")
237+
}},
226238
{"Moderations", func() (any, error) {
227239
return client.Moderations(ctx, ModerationRequest{})
228240
}},

fine_tuning_job.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/url"
8+
)
9+
10+
type FineTuningJob struct {
11+
ID string `json:"id"`
12+
Object string `json:"object"`
13+
CreatedAt int64 `json:"created_at"`
14+
FinishedAt int64 `json:"finished_at"`
15+
Model string `json:"model"`
16+
FineTunedModel string `json:"fine_tuned_model,omitempty"`
17+
OrganizationID string `json:"organization_id"`
18+
Status string `json:"status"`
19+
Hyperparameters Hyperparameters `json:"hyperparameters"`
20+
TrainingFile string `json:"training_file"`
21+
ValidationFile string `json:"validation_file,omitempty"`
22+
ResultFiles []string `json:"result_files"`
23+
TrainedTokens int `json:"trained_tokens"`
24+
}
25+
26+
type Hyperparameters struct {
27+
Epochs int `json:"n_epochs"`
28+
}
29+
30+
type FineTuningJobRequest struct {
31+
TrainingFile string `json:"training_file"`
32+
ValidationFile string `json:"validation_file,omitempty"`
33+
Model string `json:"model,omitempty"`
34+
Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
35+
Suffix string `json:"suffix,omitempty"`
36+
}
37+
38+
type FineTuningJobEventList struct {
39+
Object string `json:"object"`
40+
Data []FineTuneEvent `json:"data"`
41+
HasMore bool `json:"has_more"`
42+
}
43+
44+
type FineTuningJobEvent struct {
45+
Object string `json:"object"`
46+
ID string `json:"id"`
47+
CreatedAt int `json:"created_at"`
48+
Level string `json:"level"`
49+
Message string `json:"message"`
50+
Data any `json:"data"`
51+
Type string `json:"type"`
52+
}
53+
54+
// CreateFineTuningJob create a fine tuning job.
55+
func (c *Client) CreateFineTuningJob(
56+
ctx context.Context,
57+
request FineTuningJobRequest,
58+
) (response FineTuningJob, err error) {
59+
urlSuffix := "/fine_tuning/jobs"
60+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
61+
if err != nil {
62+
return
63+
}
64+
65+
err = c.sendRequest(req, &response)
66+
return
67+
}
68+
69+
// CancelFineTuningJob cancel a fine tuning job.
70+
func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
71+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
72+
if err != nil {
73+
return
74+
}
75+
76+
err = c.sendRequest(req, &response)
77+
return
78+
}
79+
80+
// RetrieveFineTuningJob retrieve a fine tuning job.
81+
func (c *Client) RetrieveFineTuningJob(
82+
ctx context.Context,
83+
fineTuningJobID string,
84+
) (response FineTuningJob, err error) {
85+
urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
86+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
87+
if err != nil {
88+
return
89+
}
90+
91+
err = c.sendRequest(req, &response)
92+
return
93+
}
94+
95+
type listFineTuningJobEventsParameters struct {
96+
after *string
97+
limit *int
98+
}
99+
100+
type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)
101+
102+
func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
103+
return func(args *listFineTuningJobEventsParameters) {
104+
args.after = &after
105+
}
106+
}
107+
108+
func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
109+
return func(args *listFineTuningJobEventsParameters) {
110+
args.limit = &limit
111+
}
112+
}
113+
114+
// ListFineTuningJobs list fine tuning jobs events.
115+
func (c *Client) ListFineTuningJobEvents(
116+
ctx context.Context,
117+
fineTuningJobID string,
118+
setters ...ListFineTuningJobEventsParameter,
119+
) (response FineTuningJobEventList, err error) {
120+
parameters := &listFineTuningJobEventsParameters{
121+
after: nil,
122+
limit: nil,
123+
}
124+
125+
for _, setter := range setters {
126+
setter(parameters)
127+
}
128+
129+
urlValues := url.Values{}
130+
if parameters.after != nil {
131+
urlValues.Add("after", *parameters.after)
132+
}
133+
if parameters.limit != nil {
134+
urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
135+
}
136+
137+
encodedValues := ""
138+
if len(urlValues) > 0 {
139+
encodedValues = "?" + urlValues.Encode()
140+
}
141+
142+
req, err := c.newRequest(
143+
ctx,
144+
http.MethodGet,
145+
c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
146+
)
147+
if err != nil {
148+
return
149+
}
150+
151+
err = c.sendRequest(req, &response)
152+
return
153+
}

fine_tuning_job_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package openai_test
2+
3+
import (
4+
"context"
5+
6+
. "github.com/sashabaranov/go-openai"
7+
"github.com/sashabaranov/go-openai/internal/test/checks"
8+
9+
"encoding/json"
10+
"fmt"
11+
"net/http"
12+
"testing"
13+
)
14+
15+
const testFineTuninigJobID = "fine-tuning-job-id"
16+
17+
// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server.
18+
func TestFineTuningJob(t *testing.T) {
19+
client, server, teardown := setupOpenAITestServer()
20+
defer teardown()
21+
server.RegisterHandler(
22+
"/v1/fine_tuning/jobs",
23+
func(w http.ResponseWriter, r *http.Request) {
24+
var resBytes []byte
25+
resBytes, _ = json.Marshal(FineTuningJob{})
26+
fmt.Fprintln(w, string(resBytes))
27+
},
28+
)
29+
30+
server.RegisterHandler(
31+
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
32+
func(w http.ResponseWriter, r *http.Request) {
33+
resBytes, _ := json.Marshal(FineTuningJob{})
34+
fmt.Fprintln(w, string(resBytes))
35+
},
36+
)
37+
38+
server.RegisterHandler(
39+
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
40+
func(w http.ResponseWriter, r *http.Request) {
41+
var resBytes []byte
42+
resBytes, _ = json.Marshal(FineTuningJob{})
43+
fmt.Fprintln(w, string(resBytes))
44+
},
45+
)
46+
47+
server.RegisterHandler(
48+
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
49+
func(w http.ResponseWriter, r *http.Request) {
50+
resBytes, _ := json.Marshal(FineTuningJobEventList{})
51+
fmt.Fprintln(w, string(resBytes))
52+
},
53+
)
54+
55+
ctx := context.Background()
56+
57+
_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
58+
checks.NoError(t, err, "CreateFineTuningJob error")
59+
60+
_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
61+
checks.NoError(t, err, "CancelFineTuningJob error")
62+
63+
_, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID)
64+
checks.NoError(t, err, "RetrieveFineTuningJob error")
65+
66+
_, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID)
67+
checks.NoError(t, err, "ListFineTuningJobEvents error")
68+
69+
_, err = client.ListFineTuningJobEvents(
70+
ctx,
71+
testFineTuninigJobID,
72+
ListFineTuningJobEventsWithAfter("last-event-id"),
73+
)
74+
checks.NoError(t, err, "ListFineTuningJobEvents error")
75+
76+
_, err = client.ListFineTuningJobEvents(
77+
ctx,
78+
testFineTuninigJobID,
79+
ListFineTuningJobEventsWithLimit(10),
80+
)
81+
checks.NoError(t, err, "ListFineTuningJobEvents error")
82+
83+
_, err = client.ListFineTuningJobEvents(
84+
ctx,
85+
testFineTuninigJobID,
86+
ListFineTuningJobEventsWithAfter("last-event-id"),
87+
ListFineTuningJobEventsWithLimit(10),
88+
)
89+
checks.NoError(t, err, "ListFineTuningJobEvents error")
90+
}

0 commit comments

Comments
 (0)