diff --git a/README.md b/README.md index 2d60f8b..e91734d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # workspace-provider -There are two providers that can be used to create and manage workspaces: directory and S3. +There are three providers that can be used to create and manage workspaces: directory, S3, and Azure. ## Directory @@ -21,4 +21,26 @@ You must set the following environment variables: ### Usage with S3-compatible providers (e.g. Cloudflare R2) You can use the above referenced AWS environment variables to configure the S3 provider, setting the value of the environment variable to the corresponding value from your provider. -Additionally, you should also set the `WORKSPACE_PROVIDER_S3_BASE_ENDPOINT` environment variable to the endpoint of your provider. For example, if you are using Cloudflare R2, you can set `WORKSPACE_PROVIDER_S3_BASE_ENDPOINT` to `https://.r2.cloudflarestorage.com`. \ No newline at end of file +Additionally, you should also set the `WORKSPACE_PROVIDER_S3_BASE_ENDPOINT` environment variable to the endpoint of your provider. For example, if you are using Cloudflare R2, you can set `WORKSPACE_PROVIDER_S3_BASE_ENDPOINT` to `https://.r2.cloudflarestorage.com`. + +## Azure + +The Azure provider provides an Azure Blob Storage-based workspace. + +### Setup + +1. Create an Azure Storage Account in the [Azure Portal](https://portal.azure.com) +2. Create a container in your storage account +3. Get the connection string from your storage account (under "Access keys") + +### Configuration + +You must set the following environment variables: +- `WORKSPACE_PROVIDER_AZURE_CONTAINER` - The name of your Azure Storage container +- `WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING` - The connection string for your Azure Storage account + +For example: +```bash +export WORKSPACE_PROVIDER_AZURE_CONTAINER="your-container-name" +export WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net" +``` diff --git a/go.mod b/go.mod index 7bca9fd..18974e0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/gptscript-ai/workspace-provider go 1.23.2 require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.10.0 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1 github.com/adrg/xdg v0.5.0 github.com/aws/aws-sdk-go-v2 v1.32.2 github.com/aws/aws-sdk-go-v2/config v1.27.43 @@ -16,6 +18,7 @@ require ( ) require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.41 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect @@ -43,5 +46,6 @@ require ( github.com/spf13/pflag v1.0.5 // indirect golang.org/x/net v0.31.0 // indirect golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d8de053..8a48404 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,15 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.10.0 h1:n1DH8TPV4qqPTje2RcUBYwtrTWlabVp4n46+74X2pn4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.10.0/go.mod h1:HDcZnuGbiyppErN6lB+idp4CKhjbc8gwjto6OPpyggM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 h1:sO0/P7g68FrryJzljemN+6GTssUXdANk6aJ7T1ZxnsQ= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1/go.mod h1:h8hyGFDsU5HMivxiS2iYFZsgDbU9OnnJ163x5UGVKYo= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.5.0 h1:AifHbc4mg0x9zW52WOpKbsHaDKuRhlI7TVl47thgQ70= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.5.0/go.mod h1:T5RfihdXtBDxt1Ch2wobif3TvzTdumDy29kahv6AV9A= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1 h1:fXPMAmuh0gDuRDey0atC8cXBuKIlqCzCkL8sm1n9Ov0= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1/go.mod h1:SUZc9YRRHfx2+FAQKNDGrssXehqLpxmwRv2mC/5ntj4= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1 h1:DzHpqpoJVaCgOUdVHxE8QB52S6NiVdDQvGlny1qvPqA= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/adrg/xdg v0.5.0 h1:dDaZvhMXatArP1NPHhnfaQUqWBLBsmx1h1HXQdMoFCY= github.com/adrg/xdg v0.5.0/go.mod h1:dDdY4M4DF9Rjy4kHPeNL+ilVF+p2lK8IdM9/rTSGcI4= github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= @@ -39,6 +51,8 @@ github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxY github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA= github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU= github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M= @@ -49,6 +63,8 @@ github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicb github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/safeopen v0.0.0-20240125081138-66b54d5181c6 h1:XBC2BmsUTvyeeahMA2wdvIaaKpYrr7za7F6lNK+0oL8= github.com/google/safeopen v0.0.0-20240125081138-66b54d5181c6/go.mod h1:D59KewtQCiD2Avi8N/v2zb/xTYaefwJl+ux2ejB58GQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -67,12 +83,16 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= @@ -86,13 +106,19 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/cli/workspace.go b/pkg/cli/workspace.go index 33434a6..bb369e2 100644 --- a/pkg/cli/workspace.go +++ b/pkg/cli/workspace.go @@ -10,10 +10,12 @@ import ( ) type workspaceProvider struct { - Provider string `usage:"The workspace provider to use, valid options are 'directory' and 's3'" default:"directory" env:"WORKSPACE_PROVIDER_PROVIDER,PROVIDER"` - DataHome string `usage:"The data home directory or bucket name" env:"WORKSPACE_PROVIDER_DATA_HOME"` - S3Bucket string `usage:"The S3 bucket name" name:"s3-bucket" env:"WORKSPACE_PROVIDER_S3_BUCKET"` - S3BaseEndpoint string `usage:"The S3 base endpoint to use with S3 compatible providers" name:"s3-base-endpoint" env:"WORKSPACE_PROVIDER_S3_BASE_ENDPOINT"` + Provider string `usage:"The workspace provider to use, valid options are 'directory' and 's3'" default:"directory" env:"WORKSPACE_PROVIDER_PROVIDER,PROVIDER"` + DataHome string `usage:"The data home directory or bucket name" env:"WORKSPACE_PROVIDER_DATA_HOME"` + S3Bucket string `usage:"The S3 bucket name" name:"s3-bucket" env:"WORKSPACE_PROVIDER_S3_BUCKET"` + S3BaseEndpoint string `usage:"The S3 base endpoint to use with S3 compatible providers" name:"s3-base-endpoint" env:"WORKSPACE_PROVIDER_S3_BASE_ENDPOINT"` + AzureContainer string `usage:"The Azure container name" name:"azure-container" env:"WORKSPACE_PROVIDER_AZURE_CONTAINER"` + AzureConnectionString string `usage:"The Azure connection string" name:"azure-connection-string" env:"WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING"` client *client.Client } @@ -55,15 +57,24 @@ func (w *workspaceProvider) PersistentPre(cmd *cobra.Command, _ []string) error if w.S3Bucket == "" { return fmt.Errorf("s3 provider requires a bucket name") } + case client.AzureProvider: + if w.AzureContainer == "" { + return fmt.Errorf("azure provider requires a container name") + } + if w.AzureConnectionString == "" { + return fmt.Errorf("azure provider requires a connection string") + } default: return fmt.Errorf("invalid workspace provider: %s", w.Provider) } var err error w.client, err = client.New(cmd.Context(), client.Options{ - DirectoryDataHome: w.DataHome, - S3BucketName: w.S3Bucket, - S3BaseEndpoint: w.S3BaseEndpoint, + DirectoryDataHome: w.DataHome, + S3BucketName: w.S3Bucket, + S3BaseEndpoint: w.S3BaseEndpoint, + AzureContainerName: w.AzureContainer, + AzureConnectionString: w.AzureConnectionString, }) return err diff --git a/pkg/client/azure.go b/pkg/client/azure.go new file mode 100644 index 0000000..2a4cea3 --- /dev/null +++ b/pkg/client/azure.go @@ -0,0 +1,380 @@ +package client + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + "github.com/gabriel-vasile/mimetype" + "github.com/google/uuid" +) + +func newAzure(containerName, connectionString string) (workspaceFactory, error) { + client, err := azblob.NewClientFromConnectionString(connectionString, nil) + if err != nil { + return nil, err + } + + return &azureProvider{ + containerName: containerName, + client: client, + revisionsProvider: &azureProvider{ + containerName: containerName, + dir: revisionsDir, + client: client, + }, + }, nil +} + +type azureProvider struct { + containerName, dir string + client *azblob.Client + revisionsProvider *azureProvider +} + +func (a *azureProvider) validatePath(path string) error { + if path == "" { + return nil // empty path is valid in some contexts (e.g., Ls root) + } + + // Check for path traversal attempts + if strings.Contains(path, "..") { + return fmt.Errorf("invalid path: must not contain '..'") + } + + // Check for absolute paths + if strings.HasPrefix(path, "/") { + return fmt.Errorf("invalid path: must be relative") + } + + // Azure Blob Storage naming rules: + // - Cannot start or end with '/' + // - Cannot contain consecutive forward slashes + if strings.HasSuffix(path, "/") { + return fmt.Errorf("invalid path: cannot end with '/'") + } + if strings.Contains(path, "//") { + return fmt.Errorf("invalid path: cannot contain consecutive '/'") + } + + // Additional Azure Blob Storage restrictions + if len(path) > 1024 { + return fmt.Errorf("invalid path: length cannot exceed 1024 characters") + } + + // Check for invalid characters in path segments + for _, segment := range strings.Split(path, "/") { + if segment == "" { + continue + } + if strings.ContainsAny(segment, `\:*?"<>|`) { + return fmt.Errorf("invalid path: contains invalid characters") + } + } + + return nil +} + +func (a *azureProvider) New(id string) (workspaceClient, error) { + container, dir, _ := strings.Cut(strings.TrimPrefix(id, AzureProvider+"://"), "/") + if dir == revisionsDir { + return nil, errors.New("cannot create a workspace client for the revisions directory") + } + + return &azureProvider{ + containerName: container, + dir: dir, + client: a.client, + revisionsProvider: &azureProvider{ + containerName: container, + dir: fmt.Sprintf("%s/%s", revisionsDir, dir), + client: a.client, + }, + }, nil +} + +func (a *azureProvider) Create() string { + return AzureProvider + "://" + filepath.Join(a.containerName, uuid.NewString()) +} + +func (a *azureProvider) Rm(ctx context.Context, id string) error { + container, dir, _ := strings.Cut(strings.TrimPrefix(id, AzureProvider+"://"), "/") + + newA := &azureProvider{ + containerName: container, + dir: dir, + client: a.client, + revisionsProvider: &azureProvider{ + containerName: container, + dir: fmt.Sprintf("%s/%s", revisionsDir, dir), + client: a.client, + }, + } + + // Best effort + _ = newA.revisionsProvider.RemoveAllWithPrefix(ctx, "") + + return newA.RemoveAllWithPrefix(ctx, "") +} + +func (a *azureProvider) Ls(ctx context.Context, prefix string) ([]string, error) { + if err := a.validatePath(prefix); err != nil { + return nil, err + } + if prefix != "" { + prefix = fmt.Sprintf("%s/%s/", a.dir, strings.TrimSuffix(prefix, "/")) + } else { + prefix = fmt.Sprintf("%s/", a.dir) + } + + containerClient := a.client.ServiceClient().NewContainerClient(a.containerName) + pager := containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Prefix: &prefix, + }) + + var files []string + for pager.More() { + resp, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + + for _, blob := range resp.Segment.BlobItems { + files = append(files, strings.TrimPrefix(*blob.Name, a.dir+"/")) + } + } + + return files, nil +} + +func (a *azureProvider) DeleteFile(ctx context.Context, filePath string) error { + if err := a.validatePath(filePath); err != nil { + return err + } + blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, filePath)) + _, err := blobClient.Delete(ctx, nil) + if err != nil { + var storageErr *azcore.ResponseError + if errors.As(err, &storageErr) && storageErr.StatusCode == 404 { + return nil + } + return err + } + + if a.revisionsProvider == nil { + return nil + } + + info, err := getRevisionInfo(ctx, a.revisionsProvider, filePath) + if err != nil { + return err + } + + for i := info.CurrentID; i > 0; i-- { + // Best effort + _ = deleteRevision(ctx, a.revisionsProvider, filePath, fmt.Sprintf("%d", i)) + } + + // Best effort + _ = deleteRevisionInfo(ctx, a.revisionsProvider, filePath) + + return nil +} + +func (a *azureProvider) OpenFile(ctx context.Context, filePath string, opt OpenOptions) (*File, error) { + if err := a.validatePath(filePath); err != nil { + return nil, err + } + blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, filePath)) + + resp, err := blobClient.DownloadStream(ctx, nil) + if err != nil { + var storageErr *azcore.ResponseError + if errors.As(err, &storageErr) && storageErr.StatusCode == 404 { + return nil, newNotFoundError(fmt.Sprintf("%s://%s/%s", AzureProvider, a.containerName, a.dir), filePath) + } + return nil, err + } + + var revision string + if opt.WithLatestRevisionID { + rev, err := getRevisionInfo(ctx, a.revisionsProvider, filePath) + if err != nil { + return nil, fmt.Errorf("failed to get revision info: %w", err) + } + revision = fmt.Sprintf("%d", rev.CurrentID) + } + + return &File{ + ReadCloser: resp.Body, + RevisionID: revision, + }, nil +} + +func (a *azureProvider) WriteFile(ctx context.Context, fileName string, reader io.Reader, opt WriteOptions) error { + if err := a.validatePath(fileName); err != nil { + return err + } + if a.revisionsProvider != nil && (opt.CreateRevision == nil || *opt.CreateRevision) { + info, err := getRevisionInfo(ctx, a.revisionsProvider, fileName) + if err != nil { + if nfe := (*NotFoundError)(nil); !errors.As(err, &nfe) { + return err + } + } + + if opt.LatestRevisionID != "" { + requiredLatestRevision, err := strconv.ParseInt(opt.LatestRevisionID, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse latest revision for write: %w", err) + } + + if requiredLatestRevision != info.CurrentID { + return newConflictError(AzureProvider+"://"+a.containerName, fileName, opt.LatestRevisionID, fmt.Sprintf("%d", info.CurrentID)) + } + } + + info.CurrentID++ + if err = writeRevision(ctx, a.revisionsProvider, a, fileName, info); err != nil { + if nfe := (*NotFoundError)(nil); !errors.As(err, &nfe) { + return fmt.Errorf("failed to write revision: %w", err) + } + } + + if err = writeRevisionInfo(ctx, a.revisionsProvider, fileName, info); err != nil { + return fmt.Errorf("failed to write revision info: %w", err) + } + } + + data, err := io.ReadAll(reader) + if err != nil { + return err + } + + blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, fileName)) + _, err = blobClient.UploadStream(ctx, bytes.NewReader(data), nil) + return err +} + +func (a *azureProvider) StatFile(ctx context.Context, fileName string, opt StatOptions) (FileInfo, error) { + if err := a.validatePath(fileName); err != nil { + return FileInfo{}, err + } + blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, fileName)) + + props, err := blobClient.GetProperties(ctx, nil) + if err != nil { + var storageErr *azcore.ResponseError + if errors.As(err, &storageErr) && storageErr.StatusCode == 404 { + return FileInfo{}, newNotFoundError(fmt.Sprintf("%s://%s/%s", AzureProvider, a.containerName, a.dir), fileName) + } + return FileInfo{}, err + } + + var mime string + if props.ContentType != nil { + mime = *props.ContentType + } + + // Get the first 3072 bytes of the blob to detect the mimetype + downloadOpts := &azblob.DownloadStreamOptions{} + downloadOpts.Range.Offset = 0 + downloadOpts.Range.Count = 3072 + resp, err := blobClient.DownloadStream(ctx, downloadOpts) + if err == nil { + defer resp.Body.Close() + mt, err := mimetype.DetectReader(resp.Body) + if err == nil { + mime = strings.Split(mt.String(), ";")[0] + } + } + + var modTime time.Time + if props.LastModified != nil { + modTime = *props.LastModified + } + + var revision string + if opt.WithLatestRevisionID { + rev, err := getRevisionInfo(ctx, a.revisionsProvider, fileName) + if err != nil { + return FileInfo{}, err + } + revision = fmt.Sprintf("%d", rev.CurrentID) + } + + return FileInfo{ + WorkspaceID: fmt.Sprintf("%s://%s/%s", AzureProvider, a.containerName, a.dir), + Name: strings.TrimPrefix(fileName, a.dir+"/"), + Size: *props.ContentLength, + ModTime: modTime, + MimeType: mime, + RevisionID: revision, + }, nil +} + +func (a *azureProvider) RemoveAllWithPrefix(ctx context.Context, prefix string) error { + if err := a.validatePath(prefix); err != nil { + return err + } + if prefix != "" { + prefix = fmt.Sprintf("%s/%s/", a.dir, strings.TrimSuffix(prefix, "/")) + } else { + prefix = fmt.Sprintf("%s/", a.dir) + } + + containerClient := a.client.ServiceClient().NewContainerClient(a.containerName) + pager := containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Prefix: &prefix, + }) + + for pager.More() { + resp, err := pager.NextPage(ctx) + if err != nil { + return err + } + + for _, blob := range resp.Segment.BlobItems { + blobClient := containerClient.NewBlockBlobClient(*blob.Name) + if _, err := blobClient.Delete(ctx, nil); err != nil { + return err + } + } + } + + return nil +} + +func (a *azureProvider) ListRevisions(ctx context.Context, fileName string) ([]RevisionInfo, error) { + if err := a.validatePath(fileName); err != nil { + return nil, err + } + return listRevisions(ctx, a.revisionsProvider, fmt.Sprintf("%s://%s/%s", AzureProvider, a.containerName, a.dir), fileName) +} + +func (a *azureProvider) GetRevision(ctx context.Context, fileName, revisionID string) (*File, error) { + if err := a.validatePath(fileName); err != nil { + return nil, err + } + return getRevision(ctx, a.revisionsProvider, fileName, revisionID) +} + +func (a *azureProvider) DeleteRevision(ctx context.Context, fileName, revisionID string) error { + if err := a.validatePath(fileName); err != nil { + return err + } + return deleteRevision(ctx, a.revisionsProvider, fileName, revisionID) +} + +func (a *azureProvider) RevisionClient() workspaceClient { + return a.revisionsProvider +} diff --git a/pkg/client/azure_test.go b/pkg/client/azure_test.go new file mode 100644 index 0000000..924cde4 --- /dev/null +++ b/pkg/client/azure_test.go @@ -0,0 +1,912 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +func TestCreateAndRmAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + id := azureFactory.Create() + if !strings.HasPrefix(id, AzureProvider+"://") { + t.Errorf("unexpected id: %s", id) + } + + container, dir, _ := strings.Cut(strings.TrimPrefix(id, AzureProvider+"://"), "/") + testAzureProvider := &azureProvider{ + containerName: container, + client: azurePrv.client, + } + + // Nothing should be created + blobClient := testAzureProvider.client.ServiceClient().NewContainerClient(container).NewBlockBlobClient(dir) + if _, err := blobClient.GetProperties(context.Background(), nil); err == nil { + t.Errorf("expected error when checking if workspace exists") + } else { + var storageErr *azcore.ResponseError + if !errors.As(err, &storageErr) || storageErr.StatusCode != 404 { + t.Errorf("unexpected error when checking if workspace exists: %v", err) + } + } + + if err := azureFactory.Rm(context.Background(), id); err != nil { + t.Errorf("unexpected error when removing workspace: %v", err) + } +} + +func TestWriteAndDeleteFileInAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // Ensure the file actually exists + blobClient := azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", azurePrv.dir, "test.txt")) + props, err := blobClient.GetProperties(context.Background(), nil) + if err != nil { + t.Errorf("error when checking if file exists: %v", err) + } + + // Stat the file and compare with the original + providerStat, err := azurePrv.StatFile(context.Background(), "test.txt", StatOptions{}) + if err != nil { + t.Errorf("unexpected error when statting file: %v", err) + } + + if providerStat.WorkspaceID != azureTestingID { + t.Errorf("unexpected workspace id: %s", providerStat.WorkspaceID) + } + if providerStat.Size != *props.ContentLength { + t.Errorf("unexpected file size: %d", providerStat.Size) + } + if providerStat.Name != "test.txt" { + t.Errorf("unexpected file name: %s", providerStat.Name) + } + if providerStat.ModTime.Compare(*props.LastModified) != 0 { + t.Errorf("unexpected file mod time: %s", providerStat.ModTime) + } + + // Delete the file + if err := azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } + + // Ensure the file no longer exists + if _, err := blobClient.GetProperties(context.Background(), nil); err == nil { + t.Errorf("file should not exist after deleting") + } +} + +func TestWriteAndDeleteFileInAzureWithSubDir(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + filePath := filepath.Join("subdir", "test.txt") + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), filePath, strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // Ensure the file actually exists + blobClient := azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", azurePrv.dir, filePath)) + if _, err := blobClient.GetProperties(context.Background(), nil); err != nil { + t.Errorf("error when checking if file exists: %v", err) + } + + // Delete the file + if err := azurePrv.DeleteFile(context.Background(), filePath); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } + + // Ensure the file no longer exists + if _, err := blobClient.GetProperties(context.Background(), nil); err == nil { + t.Errorf("file should not exist after deleting") + } +} + +func TestFileReadFromAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + readFile, err := azurePrv.OpenFile(context.Background(), "test.txt", OpenOptions{}) + if err != nil { + t.Errorf("unexpected error when reading file: %v", err) + } + + content, err := io.ReadAll(readFile) + if err != nil { + t.Errorf("unexpected error when reading file: %v", err) + } + + if err = readFile.Close(); err != nil { + t.Errorf("error closing file: %v", err) + } + + if string(content) != "test" { + t.Errorf("unexpected content: %s", string(content)) + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } + + // Deleting the file again should not throw an error + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } +} + +func TestLsAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Write a bunch of files to the directory. They can be blank + for i := range 7 { + fileName := fmt.Sprintf("test%d.txt", i) + if err := azurePrv.WriteFile(context.Background(), fileName, strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // deferring here is fine because these files shouldn't be deleted until the end of the test + defer func(name string) { + err := azurePrv.DeleteFile(context.Background(), name) + if err != nil { + t.Errorf("unexpected error when deleting file %s: %v", name, err) + } + }(fileName) + } + + contents, err := azurePrv.Ls(context.Background(), "") + if err != nil { + t.Fatalf("unexpected error when listing files: %v", err) + } + + if len(contents) != 7 { + t.Errorf("unexpected number of files: %d", len(contents)) + } + + sort.Strings(contents) + if !reflect.DeepEqual( + contents, + []string{ + "test0.txt", + "test1.txt", + "test2.txt", + "test3.txt", + "test4.txt", + "test5.txt", + "test6.txt", + }, + ) { + t.Errorf("unexpected contents: %v", contents) + } +} + +func TestLsWithSubDirsAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + defer func() { + err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir") + if err != nil { + t.Errorf("unexpected error when deleting file %s: %v", "testDir", err) + } + }() + + // Write a bunch of files to the directory. They can be blank + for i := range 7 { + fileName := fmt.Sprintf("test%d.txt", i) + if i >= 3 { + fileName = fmt.Sprintf("testDir/%s", fileName) + } + if err := azurePrv.WriteFile(context.Background(), fileName, strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // deferring here is fine because these files shouldn't be deleted until the end of the test + defer func(name string) { + err := azurePrv.DeleteFile(context.Background(), name) + if err != nil { + t.Errorf("unexpected error when deleting file %s: %v", name, err) + } + }(fileName) + } + + contents, err := azurePrv.Ls(context.Background(), "") + if err != nil { + t.Fatalf("unexpected error when listing files: %v", err) + } + + if len(contents) != 7 { + t.Errorf("unexpected number of children: %d", len(contents)) + } + + sort.Strings(contents) + if !reflect.DeepEqual( + contents, + []string{ + "test0.txt", + "test1.txt", + "test2.txt", + filepath.Join("testDir", "test3.txt"), + filepath.Join("testDir", "test4.txt"), + filepath.Join("testDir", "test5.txt"), + filepath.Join("testDir", "test6.txt"), + }, + ) { + t.Errorf("unexpected contents: %v", contents) + } +} + +func TestLsWithPrefixAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + defer func() { + err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir") + if err != nil { + t.Errorf("unexpected error when deleting file %s: %v", "testDir", err) + } + }() + + // Write a bunch of files to the directory. They can be blank + for i := range 7 { + fileName := fmt.Sprintf("test%d.txt", i) + if i >= 3 { + fileName = fmt.Sprintf("testDir/%s", fileName) + } + if err := azurePrv.WriteFile(context.Background(), fileName, strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // deferring here is fine because these files shouldn't be deleted until the end of the test + defer func(name string) { + err := azurePrv.DeleteFile(context.Background(), name) + if err != nil { + t.Errorf("unexpected error when deleting file %s: %v", name, err) + } + }(fileName) + } + + contents, err := azurePrv.Ls(context.Background(), "testDir") + if err != nil { + t.Fatalf("unexpected error when listing files: %v", err) + } + + if len(contents) != 4 { + t.Errorf("unexpected number of contents: %d", len(contents)) + } + + sort.Strings(contents) + if !reflect.DeepEqual( + contents, + []string{ + filepath.Join("testDir", "test3.txt"), + filepath.Join("testDir", "test4.txt"), + filepath.Join("testDir", "test5.txt"), + filepath.Join("testDir", "test6.txt"), + }, + ) { + t.Errorf("unexpected contents: %v", contents) + } +} + +func TestRemoveAllWithPrefixAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Write a bunch of files to the directory. They can be blank + for i := range 7 { + fileName := fmt.Sprintf("test%d.txt", i) + if i >= 3 { + fileName = fmt.Sprintf("testDir/%s", fileName) + } + if err := azurePrv.WriteFile(context.Background(), fileName, strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // deferring here is fine because these files shouldn't be deleted until the end of the test + defer func(name string) { + err := azurePrv.DeleteFile(context.Background(), name) + if fnf := (*NotFoundError)(nil); err != nil && !errors.As(err, &fnf) { + t.Errorf("unexpected error when deleting file %s: %v", name, err) + } + }(fileName) + } + + err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir") + if err != nil { + t.Errorf("unexpected error when deleting all with prefix testDir: %v", err) + } + + contents, err := azurePrv.Ls(context.Background(), "") + if err != nil { + t.Fatalf("unexpected error when listing files: %v", err) + } + + if len(contents) != 3 { + t.Errorf("unexpected number of children: %d", len(contents)) + } + + sort.Strings(contents) + if !reflect.DeepEqual( + contents, + []string{ + "test0.txt", + "test1.txt", + "test2.txt", + }, + ) { + t.Errorf("unexpected contents: %v", contents) + } +} + +func TestOpeningFileDNENoErrorAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + var notFoundError *NotFoundError + if file, err := azurePrv.OpenFile(context.Background(), "test.txt", OpenOptions{}); err == nil { + _ = file.Close() + t.Errorf("expected error when opening file that doesn't exist") + } else if !errors.As(err, ¬FoundError) { + t.Errorf("expected not found error when opening file that doesn't exist") + } +} + +// Add revision-related tests +func TestWriteEnsureRevisionAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // List revisions, there should be none + revisions, err := azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Ensure the revision file exists + blobClient := azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("revisions/%s/%s.1", azurePrv.dir, "test.txt")) + props, err := blobClient.GetProperties(context.Background(), nil) + if err != nil { + t.Errorf("error when checking if revision exists: %v", err) + } + + // Now there should be one revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 1 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } else { + if revisions[0].WorkspaceID != azureTestingID { + t.Errorf("unexpected workspace id: %s", revisions[0].WorkspaceID) + } + if revisions[0].Size != *props.ContentLength { + t.Errorf("unexpected file size: %d", revisions[0].Size) + } + if revisions[0].Name != "test.txt" { + t.Errorf("unexpected file name: %s", revisions[0].Name) + } + if revisions[0].ModTime.Compare(*props.LastModified) != 0 { + t.Errorf("unexpected file mod time: %s", revisions[0].ModTime) + } + if revisions[0].RevisionID != "1" { + t.Errorf("unexpected revision id: %s", revisions[0].RevisionID) + } + + // Get the revision and ensure that it has the correct content + rev, err := azurePrv.GetRevision(context.Background(), "test.txt", revisions[0].RevisionID) + if err != nil { + t.Errorf("unexpected error when getting revision: %v", err) + } else { + defer rev.Close() + } + + content, err := io.ReadAll(rev) + if err != nil { + t.Errorf("unexpected error when reading revision: %v", err) + } + + if string(content) != "test" { + t.Errorf("unexpected content: %s", string(content)) + } + + revisionID, err := rev.GetRevisionID() + if err != nil { + t.Errorf("error getting revision: %v", err) + } + if revisionID != "1" { + t.Errorf("unexpected revision ID: %s", revisionID) + } + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } + + // Ensure the file no longer exists + if _, err = azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", azurePrv.dir, "test.txt")).GetProperties(context.Background(), nil); err == nil { + t.Errorf("file should not exist after deleting") + } + + // Ensure the revision file no longer exists + if _, err = azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("revisions/%s/%s.1", azurePrv.dir, "test.txt")).GetProperties(context.Background(), nil); err == nil { + t.Errorf("revision should not exist after deleting") + } + + // Ensure the API returns no revisions for the file + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } +} + +func TestWriteEnsureConflictAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // List revisions, there should be none + revisions, err := azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + ce := (*ConflictError)(nil) + // Trying to update the file with a non-zero revision ID should fail with a conflict error + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{LatestRevisionID: "1"}); err == nil || !errors.As(err, &ce) { + t.Errorf("expected error when first updating file non-zero revision ID: %v", err) + } + + // Also, using -1 for the revision ID should also fail because that is the same as "only write if the file doesn't exist" + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{LatestRevisionID: "-1"}); err == nil || !errors.As(err, &ce) { + t.Errorf("expected error when first updating file non-zero revision ID: %v", err) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Now there should be one revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 1 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Update the file again + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test3"), WriteOptions{LatestRevisionID: revisions[0].RevisionID}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Now there should be two revisions + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 2 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Trying to update the file again with the same revision ID should fail with a conflict error + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test4"), WriteOptions{LatestRevisionID: revisions[0].RevisionID}); err == nil || !errors.As(err, &ce) { + t.Errorf("expected error when updating file with same revision ID: %v", err) + } + + latestRevisionID := revisions[1].RevisionID + // Delete the most recent revision + if err = azurePrv.DeleteRevision(context.Background(), "test.txt", latestRevisionID); err != nil { + t.Errorf("error deleting revision: %v", err) + } + + // Now there should be one revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 1 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // We cannot update the file with this revision ID + ce = (*ConflictError)(nil) + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test5"), WriteOptions{LatestRevisionID: revisions[0].RevisionID}); err == nil || !errors.As(err, &ce) { + t.Errorf("expected error when updating file with zero revision ID: %v", err) + } + + // Ensure that we can still create a new revision + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test5"), WriteOptions{LatestRevisionID: latestRevisionID}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("error removing file: %v", err) + } +} + +func TestDeleteRevisionAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // List revisions, there should be none + revisions, err := azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Now there should be one revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 1 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test3"), WriteOptions{}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Now there should be two revisions + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 2 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Delete the first revision + if err = azurePrv.DeleteRevision(context.Background(), "test.txt", "1"); err != nil { + t.Errorf("unexpected error when deleting revision: %v", err) + } + + // Now there should be one revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 1 || revisions[0].RevisionID != "2" { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Deleting the revision again should not produce an error + if err = azurePrv.DeleteRevision(context.Background(), "test.txt", "1"); err != nil { + t.Errorf("unexpected error when deleting revision: %v", err) + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } + + // Ensure the file no longer exists + if _, err = azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", azurePrv.dir, "test.txt")).GetProperties(context.Background(), nil); err == nil { + t.Errorf("file should not exist after deleting") + } + + // Ensure the revision file no longer exists + if _, err = azurePrv.client.ServiceClient().NewContainerClient(azurePrv.containerName).NewBlockBlobClient(fmt.Sprintf("revisions/%s/%s.2", azurePrv.dir, "test.txt")).GetProperties(context.Background(), nil); err == nil { + t.Errorf("revision should not exist after deleting") + } +} + +func TestNoCreateRevisionsClientAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + _, err := azureFactory.New(fmt.Sprintf("%s://%s/%s", AzureProvider, os.Getenv("WORKSPACE_PROVIDER_AZURE_CONTAINER"), revisionsDir)) + if err == nil { + t.Errorf("expected error when creating client for revisions dir") + } +} + +func TestPathValidationAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + tests := []struct { + name string + path string + wantErr bool + errMsg string + testFunc func(string) error + }{ + // Path traversal tests + {"traversal parent", "../test.txt", true, "must not contain '..'", nil}, + {"traversal nested", "foo/../../test.txt", true, "must not contain '..'", nil}, + {"traversal with slash", "../test.txt/", true, "must not contain '..'", nil}, + + // Absolute path tests + {"absolute path", "/test.txt", true, "must be relative", nil}, + {"absolute nested", "/foo/test.txt", true, "must be relative", nil}, + + // Azure naming rule tests + {"trailing slash", "test/", true, "cannot end with '/'", nil}, + {"double slash", "foo//bar.txt", true, "cannot contain consecutive '/'", nil}, + {"invalid chars", "test*.txt", true, "contains invalid characters", nil}, + {"invalid chars nested", "foo/test*.txt", true, "contains invalid characters", nil}, + {"long path", strings.Repeat("a/", 1000) + "a.txt", true, "length cannot exceed 1024", nil}, + + // Valid paths + {"simple file", "test.txt", false, "", nil}, + {"nested file", "foo/bar/test.txt", false, "", nil}, + {"with numbers", "test123.txt", false, "", nil}, + {"with dash", "test-file.txt", false, "", nil}, + {"with underscore", "test_file.txt", false, "", nil}, + } + + // Create a test file to verify existence checks + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error creating test file: %v", err) + } + defer func() { + if err := azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("error deleting test file: %v", err) + } + }() + + for _, tt := range tests { + t.Run(fmt.Sprintf("WriteFile/%s", tt.name), func(t *testing.T) { + err := azurePrv.WriteFile(context.Background(), tt.path, strings.NewReader("test"), WriteOptions{}) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("OpenFile/%s", tt.name), func(t *testing.T) { + _, err := azurePrv.OpenFile(context.Background(), tt.path, OpenOptions{}) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("StatFile/%s", tt.name), func(t *testing.T) { + _, err := azurePrv.StatFile(context.Background(), tt.path, StatOptions{}) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("DeleteFile/%s", tt.name), func(t *testing.T) { + err := azurePrv.DeleteFile(context.Background(), tt.path) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("Ls/%s", tt.name), func(t *testing.T) { + _, err := azurePrv.Ls(context.Background(), tt.path) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("RemoveAllWithPrefix/%s", tt.name), func(t *testing.T) { + err := azurePrv.RemoveAllWithPrefix(context.Background(), tt.path) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("ListRevisions/%s", tt.name), func(t *testing.T) { + _, err := azurePrv.ListRevisions(context.Background(), tt.path) + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + + t.Run(fmt.Sprintf("GetRevision/%s", tt.name), func(t *testing.T) { + if tt.wantErr { + _, err := azurePrv.GetRevision(context.Background(), tt.path, "1") + assertPathError(t, err, tt.wantErr, tt.errMsg) + } + }) + + t.Run(fmt.Sprintf("DeleteRevision/%s", tt.name), func(t *testing.T) { + err := azurePrv.DeleteRevision(context.Background(), tt.path, "1") + assertPathError(t, err, tt.wantErr, tt.errMsg) + }) + } +} + +// Helper function to assert path validation errors +func assertPathError(t *testing.T, err error, wantErr bool, errMsg string) { + t.Helper() + if wantErr { + if err == nil { + t.Error("expected error but got none") + return + } + if !strings.Contains(err.Error(), errMsg) { + t.Errorf("expected error containing %q, got %v", errMsg, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestReadFileWithRevisionAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // Read the file + f, err := azurePrv.OpenFile(context.Background(), "test.txt", OpenOptions{WithLatestRevisionID: true}) + if err != nil { + t.Errorf("error reading file: %v", err) + } + + // Read the file contents + data, err := io.ReadAll(f) + if err != nil { + t.Errorf("error reading file contents: %v", err) + } + + // Close the file + if err := f.Close(); err != nil { + t.Errorf("error closing file: %v", err) + } + + if string(data) != "test" { + t.Errorf("unexpected file contents: %s", string(data)) + } + + // Ensure that the revision is set and correct + revisionID, err := f.GetRevisionID() + if err != nil { + t.Errorf("error getting revision: %v", err) + } + if revisionID != "0" { + t.Errorf("unexpected revision ID: %s", revisionID) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{LatestRevisionID: "0"}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Read the file + f, err = azurePrv.OpenFile(context.Background(), "test.txt", OpenOptions{WithLatestRevisionID: true}) + if err != nil { + t.Errorf("error reading file: %v", err) + } + + // Read the file contents + data, err = io.ReadAll(f) + if err != nil { + t.Errorf("error reading file contents: %v", err) + } + + // Close the file + if err := f.Close(); err != nil { + t.Errorf("error closing file: %v", err) + } + + if string(data) != "test2" { + t.Errorf("unexpected file contents: %s", string(data)) + } + + // Get the revision ID + revisionID, err = f.GetRevisionID() + if err != nil { + t.Errorf("error getting revision: %v", err) + } + if revisionID != "1" { + t.Errorf("unexpected revision ID: %s", revisionID) + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("error removing file: %v", err) + } +} + +func TestWriteEnsureNoRevisionAzure(t *testing.T) { + if skipAzureTests { + t.Skip("Skipping Azure tests") + } + + createRevision := false + // Copy a file into the workspace + if err := azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test"), WriteOptions{CreateRevision: &createRevision}); err != nil { + t.Fatalf("error getting file to write: %v", err) + } + + // List revisions, there should be none + revisions, err := azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Update the file + if err = azurePrv.WriteFile(context.Background(), "test.txt", strings.NewReader("test2"), WriteOptions{CreateRevision: &createRevision}); err != nil { + t.Errorf("error getting file to write: %v", err) + } + + // Now there should still be no revision + revisions, err = azurePrv.ListRevisions(context.Background(), "test.txt") + if err != nil { + t.Errorf("unexpected error when listing revisions: %v", err) + } + if len(revisions) != 0 { + t.Errorf("unexpected number of revisions: %d", len(revisions)) + } + + // Delete the file + if err = azurePrv.DeleteFile(context.Background(), "test.txt"); err != nil { + t.Errorf("unexpected error when deleting file: %v", err) + } +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 3cad988..8ee6f0c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -16,6 +16,7 @@ import ( const ( DirectoryProvider = "directory" S3Provider = "s3" + AzureProvider = "azure" ) type workspaceFactory interface { @@ -38,9 +39,11 @@ type workspaceClient interface { } type Options struct { - DirectoryDataHome string - S3BucketName string - S3BaseEndpoint string + DirectoryDataHome string + S3BucketName string + S3BaseEndpoint string + AzureContainerName string + AzureConnectionString string } func complete(opts ...Options) Options { @@ -56,6 +59,12 @@ func complete(opts ...Options) Options { if o.S3BaseEndpoint != "" { opt.S3BaseEndpoint = o.S3BaseEndpoint } + if o.AzureContainerName != "" { + opt.AzureContainerName = o.AzureContainerName + } + if o.AzureConnectionString != "" { + opt.AzureConnectionString = o.AzureConnectionString + } } if opt.DirectoryDataHome == "" { @@ -68,19 +77,27 @@ func complete(opts ...Options) Options { func New(ctx context.Context, opts ...Options) (*Client, error) { opt := complete(opts...) - var s3 workspaceFactory + factories := map[string]workspaceFactory{ + DirectoryProvider: newDirectory(opt.DirectoryDataHome), + } + if opt.S3BucketName != "" { - var err error - s3, err = newS3(ctx, opt.S3BucketName, opt.S3BaseEndpoint) + factory, err := newS3(ctx, opt.S3BucketName, opt.S3BaseEndpoint) if err != nil { return nil, err } + factories[S3Provider] = factory } + if opt.AzureConnectionString != "" { + factory, err := newAzure(opt.AzureContainerName, opt.AzureConnectionString) + if err != nil { + return nil, err + } + factories[AzureProvider] = factory + } + return &Client{ - factories: map[string]workspaceFactory{ - DirectoryProvider: newDirectory(opt.DirectoryDataHome), - S3Provider: s3, - }, + factories: factories, }, nil } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 5edf5de..77f5e88 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -13,20 +13,30 @@ import ( ) var c, _ = New(context.Background(), Options{ - S3BucketName: os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET"), - S3BaseEndpoint: os.Getenv("WORKSPACE_PROVIDER_S3_BASE_ENDPOINT"), + S3BucketName: os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET"), + S3BaseEndpoint: os.Getenv("WORKSPACE_PROVIDER_S3_BASE_ENDPOINT"), + AzureContainerName: os.Getenv("WORKSPACE_PROVIDER_AZURE_CONTAINER"), + AzureConnectionString: os.Getenv("WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING"), }) func TestProviders(t *testing.T) { providers := c.Providers() for _, p := range providers { - if p != DirectoryProvider && p != S3Provider { + if p != DirectoryProvider && p != S3Provider && p != AzureProvider { t.Errorf("invalid provider: %s", p) } } - if len(providers) != 2 { + expectedCount := 1 + if !skipAzureTests { + expectedCount++ + } + if !skipS3Tests { + expectedCount++ + } + + if len(providers) != expectedCount { t.Errorf("unexpected number of providers: %d", len(providers)) } } diff --git a/pkg/client/directory_test.go b/pkg/client/directory_test.go index 27cdc8e..ee95cad 100644 --- a/pkg/client/directory_test.go +++ b/pkg/client/directory_test.go @@ -16,11 +16,15 @@ import ( var ( directoryFactory workspaceFactory s3Factory workspaceFactory + azureFactory workspaceFactory directoryTestingID string s3TestingID string + azureTestingID string dirPrv workspaceClient s3Prv *s3Provider + azurePrv *azureProvider skipS3Tests = os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" + skipAzureTests = os.Getenv("WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING") == "" || os.Getenv("WORKSPACE_PROVIDER_AZURE_CONTAINER") == "" ) func TestMain(m *testing.M) { @@ -37,6 +41,15 @@ func TestMain(m *testing.M) { s3Prv = s3Client.(*s3Provider) } + if !skipAzureTests { + azureFactory, _ = newAzure(os.Getenv("WORKSPACE_PROVIDER_AZURE_CONTAINER"), os.Getenv("WORKSPACE_PROVIDER_AZURE_CONNECTION_STRING")) + // This won't ever error because it doesn't create anything. + azureTestingID = azureFactory.Create() + + azureClient, _ := azureFactory.New(azureTestingID) + azurePrv = azureClient.(*azureProvider) + } + exitCode := m.Run() var errs []error @@ -50,6 +63,12 @@ func TestMain(m *testing.M) { } } + if !skipAzureTests { + if err := azureFactory.Rm(context.Background(), azureTestingID); err != nil { + errs = append(errs, fmt.Errorf("error removing azure workspace: %v", err)) + } + } + if err := errors.Join(errs...); err != nil { panic(err) }