diff --git a/datasets.go b/datasets.go index c053a26..7b9f377 100644 --- a/datasets.go +++ b/datasets.go @@ -30,9 +30,10 @@ type Dataset struct { } type datasetRequest struct { - Input string `json:"input"` - Workspace string `json:"workspace"` - DatasetToolRepo string `json:"datasetToolRepo"` + Input string `json:"input"` + WorkspaceID string `json:"workspaceID"` + DatasetToolRepo string `json:"datasetToolRepo"` + Env []string `json:"env"` } type createDatasetArgs struct { @@ -47,6 +48,11 @@ type addDatasetElementArgs struct { ElementContent string `json:"elementContent"` } +type addDatasetElementsArgs struct { + DatasetID string `json:"datasetID"` + Elements []DatasetElement `json:"elements"` +} + type listDatasetElementArgs struct { DatasetID string `json:"datasetID"` } @@ -56,15 +62,16 @@ type getDatasetElementArgs struct { Element string `json:"element"` } -func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) { - if workspace == "" { - workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR") +func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{ Input: "{}", - Workspace: workspace, + WorkspaceID: workspaceID, DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, }) if err != nil { return nil, err @@ -77,9 +84,9 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]Datas return datasets, nil } -func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) { - if workspace == "" { - workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR") +func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } args := createDatasetArgs{ @@ -93,8 +100,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{ Input: string(argsJSON), - Workspace: workspace, + WorkspaceID: workspaceID, DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, }) if err != nil { return Dataset{}, err @@ -107,9 +115,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript return dataset, nil } -func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) { - if workspace == "" { - workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR") +func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } args := addDatasetElementArgs{ @@ -125,8 +133,9 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{ Input: string(argsJSON), - Workspace: workspace, + WorkspaceID: workspaceID, DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, }) if err != nil { return DatasetElementMeta{}, err @@ -139,9 +148,32 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, return element, nil } -func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) { - if workspace == "" { - workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR") +func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + args := addDatasetElementsArgs{ + DatasetID: datasetID, + Elements: elements, + } + argsJSON, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal element args: %w", err) + } + + _, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{ + Input: string(argsJSON), + WorkspaceID: workspaceID, + DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, + }) + return err +} + +func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } args := listDatasetElementArgs{ @@ -154,8 +186,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{ Input: string(argsJSON), - Workspace: workspace, + WorkspaceID: workspaceID, DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, }) if err != nil { return nil, err @@ -168,9 +201,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI return elements, nil } -func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) { - if workspace == "" { - workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR") +func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) { + if workspaceID == "" { + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } args := getDatasetElementArgs{ @@ -184,8 +217,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{ Input: string(argsJSON), - Workspace: workspace, + WorkspaceID: workspaceID, DatasetToolRepo: g.globalOpts.DatasetToolRepo, + Env: g.globalOpts.Env, }) if err != nil { return DatasetElement{}, err diff --git a/datasets_test.go b/datasets_test.go index 3763982..44d72ab 100644 --- a/datasets_test.go +++ b/datasets_test.go @@ -2,48 +2,72 @@ package gptscript import ( "context" - "os" "testing" "github.com/stretchr/testify/require" ) func TestDatasets(t *testing.T) { - workspace, err := os.MkdirTemp("", "go-gptscript-test") + workspaceID, err := g.CreateWorkspace(context.Background(), "directory") require.NoError(t, err) + defer func() { - _ = os.RemoveAll(workspace) + _ = g.DeleteWorkspace(context.Background(), DeleteWorkspaceOptions{WorkspaceID: workspaceID}) }() // Create a dataset - dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset") + dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset") require.NoError(t, err) require.Equal(t, "test-dataset", dataset.Name) require.Equal(t, "This is a test dataset", dataset.Description) require.Equal(t, 0, len(dataset.Elements)) // Add an element - elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content") + elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", "This is the content") require.NoError(t, err) require.Equal(t, "test-element", elementMeta.Name) require.Equal(t, "This is a test element", elementMeta.Description) - // Get the element - element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element") + // Add two more + err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{ + { + DatasetElementMeta: DatasetElementMeta{ + Name: "test-element-2", + Description: "This is a test element 2", + }, + Contents: "This is the content 2", + }, + { + DatasetElementMeta: DatasetElementMeta{ + Name: "test-element-3", + Description: "This is a test element 3", + }, + Contents: "This is the content 3", + }, + }) + require.NoError(t, err) + + // Get the first element + element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element") require.NoError(t, err) require.Equal(t, "test-element", element.Name) require.Equal(t, "This is a test element", element.Description) require.Equal(t, "This is the content", element.Contents) + // Get the third element + element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3") + require.NoError(t, err) + require.Equal(t, "test-element-3", element.Name) + require.Equal(t, "This is a test element 3", element.Description) + require.Equal(t, "This is the content 3", element.Contents) + // List elements in the dataset - elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID) + elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID) require.NoError(t, err) - require.Equal(t, 1, len(elements)) - require.Equal(t, "test-element", elements[0].Name) - require.Equal(t, "This is a test element", elements[0].Description) + require.Equal(t, 3, len(elements)) // List datasets - datasets, err := g.ListDatasets(context.Background(), workspace) + datasets, err := g.ListDatasets(context.Background(), workspaceID) require.NoError(t, err) require.Equal(t, 1, len(datasets)) require.Equal(t, "test-dataset", datasets[0].Name)