Skip to content

Commit 3117e9e

Browse files
zenovoreAlexander
and
Alexander
authored
refactor: adding context and merge artifact service creation into mlflow (caraml-dev#80)
* refactor: adding context and merge artifact service creation into mlflow
* refactor: mlflow unit test
* refactor: change context argument sequence
* refactor: change unit test context sequence
* refactor: naming service and client
* refactor: remove gcs config for context
* refactor: change definition order, struct naming

---------

Co-authored-by: Alexander <[email protected]>
1 parent 543add5 commit 3117e9e

File tree

5 files changed

+177
-83
lines changed

5 files changed

+177
-83
lines changed

api/pkg/artifact/artifact.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,24 @@ import (
88
"google.golang.org/api/iterator"
99
)
1010

11-
type gcsClient struct {
12-
API *storage.Client
13-
Config Config
14-
}
15-
type Config struct {
16-
Ctx context.Context
17-
}
18-
1911
type Service interface {
20-
DeleteArtifact(url string) error
12+
DeleteArtifact(ctx context.Context, url string) error
2113
}
2214

23-
func NewGcsClient(api *storage.Client, cfg Config) Service {
24-
return &gcsClient{
25-
API: api,
26-
Config: cfg,
27-
}
15+
type GcsArtifactClient struct {
16+
API *storage.Client
2817
}
2918

30-
func (gc *gcsClient) DeleteArtifact(url string) error {
19+
func (gac *GcsArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
3120
// Get bucket name and gcsPrefix
3221
// the [5:] is to remove the "gs://" on the artifact uri
3322
// ex : gs://bucketName/path → bucketName/path
34-
gcsBucket, gcsLocation := gc.getGcsBucketAndLocation(url[5:])
23+
gcsBucket, gcsLocation := gac.getGcsBucketAndLocation(url[5:])
3524

3625
// Sets the name for the bucket.
37-
bucket := gc.API.Bucket(gcsBucket)
26+
bucket := gac.API.Bucket(gcsBucket)
3827

39-
it := bucket.Objects(gc.Config.Ctx, &storage.Query{
28+
it := bucket.Objects(ctx, &storage.Query{
4029
Prefix: gcsLocation,
4130
})
4231
for {
@@ -47,16 +36,32 @@ func (gc *gcsClient) DeleteArtifact(url string) error {
4736
if err != nil {
4837
return err
4938
}
50-
if err := bucket.Object(attrs.Name).Delete(gc.Config.Ctx); err != nil {
39+
if err := bucket.Object(attrs.Name).Delete(ctx); err != nil {
5140
return err
5241
}
5342
}
5443
return nil
5544
}
5645

57-
func (gc *gcsClient) getGcsBucketAndLocation(str string) (string, string) {
46+
func (gac *GcsArtifactClient) getGcsBucketAndLocation(str string) (string, string) {
5847
// Split string using delimiter
5948
// ex : bucketName/path/path1/item → (bucketName , path/path1/item)
6049
splitStr := strings.SplitN(str, "/", 2)
6150
return splitStr[0], splitStr[1]
6251
}
52+
53+
func NewGcsArtifactClient(api *storage.Client) Service {
54+
return &GcsArtifactClient{
55+
API: api,
56+
}
57+
}
58+
59+
type NopArtifactClient struct{}
60+
61+
func (nac *NopArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
62+
return nil
63+
}
64+
65+
func NewNopArtifactClient() Service {
66+
return &NopArtifactClient{}
67+
}

api/pkg/artifact/mocks/artifact.go

Lines changed: 16 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/pkg/client/mlflow/mlflow.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,21 @@ package mlflow
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
89

10+
"cloud.google.com/go/storage"
11+
912
"github.com/gojek/mlp/api/pkg/artifact"
1013
)
1114

1215
type Service interface {
1316
searchRunsForExperiment(ExperimentID string) (SearchRunsResponse, error)
1417
searchRunData(RunID string) (SearchRunResponse, error)
15-
DeleteExperiment(ExperimentID string, deleteArtifact bool) error
16-
DeleteRun(RunID, artifactURL string, deleteArtifact bool) error
18+
DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error
19+
DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error
1720
}
1821

1922
type mlflowService struct {
@@ -22,12 +25,24 @@ type mlflowService struct {
2225
Config Config
2326
}
2427

25-
func NewMlflowService(httpClient *http.Client, config Config, artifactService artifact.Service) Service {
28+
func NewMlflowService(httpClient *http.Client, config Config) (Service, error) {
29+
var artifactService artifact.Service
30+
if config.ArtifactServiceType == "nop" {
31+
artifactService = artifact.NewNopArtifactClient()
32+
} else if config.ArtifactServiceType == "gcs" {
33+
api, err := storage.NewClient(context.Background())
34+
if err != nil {
35+
return &mlflowService{}, fmt.Errorf("failed initializing gcs for mlflow delete package")
36+
}
37+
artifactService = artifact.NewGcsArtifactClient(api)
38+
} else {
39+
return &mlflowService{}, fmt.Errorf("invalid artifact service type")
40+
}
2641
return &mlflowService{
2742
API: httpClient,
2843
Config: config,
2944
ArtifactService: artifactService,
30-
}
45+
}, nil
3146
}
3247

3348
func (mfs *mlflowService) httpCall(method string, url string, body []byte, response interface{}) error {
@@ -101,19 +116,19 @@ func (mfs *mlflowService) searchRunData(RunID string) (SearchRunResponse, error)
101116
return runResponse, nil
102117
}
103118

104-
func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact bool) error {
119+
func (mfs *mlflowService) DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error {
105120

106121
relatedRunData, err := mfs.searchRunsForExperiment(ExperimentID)
107122
if err != nil {
108123
return err
109124
}
110125
// Error handling for empty/no run for the experiment
111126
if len(relatedRunData.RunsData) == 0 {
112-
return fmt.Errorf("There are no related run for experiment id %s", ExperimentID)
127+
return fmt.Errorf("there are no related run for experiment id %s", ExperimentID)
113128
}
114129
// Error Handling, when a RunID failed to delete return error
115130
for _, run := range relatedRunData.RunsData {
116-
err = mfs.DeleteRun(run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
131+
err = mfs.DeleteRun(ctx, run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
117132
if err != nil {
118133
return fmt.Errorf("deletion failed for run_id %s for experiment id %s: %s", run.Info.RunID, ExperimentID, err)
119134
}
@@ -122,7 +137,7 @@ func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact b
122137
return nil
123138
}
124139

125-
func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bool) error {
140+
func (mfs *mlflowService) DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error {
126141
if artifactURL == "" {
127142
runDetail, err := mfs.searchRunData(RunID)
128143
if err != nil {
@@ -131,7 +146,7 @@ func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bo
131146
artifactURL = runDetail.RunData.Info.ArtifactURI
132147
}
133148
if deleteArtifact {
134-
err := mfs.ArtifactService.DeleteArtifact(artifactURL)
149+
err := mfs.ArtifactService.DeleteArtifact(ctx, artifactURL)
135150
if err != nil {
136151
return err
137152
}

0 commit comments

Comments
 (0)