diff --git a/ociregistry/ocifilter/select.go b/ociregistry/ocifilter/select.go index 64f75fc5..e396e78b 100644 --- a/ociregistry/ocifilter/select.go +++ b/ociregistry/ocifilter/select.go @@ -21,136 +21,182 @@ import ( "cuelabs.dev/go/oci/ociregistry" ) -// Select returns a wrapper for r that provides only -// repositories for which allow returns true. +// AccessKind +type AccessKind int + +const ( + // [ociregistry.Reader] methods. + AccessRead AccessKind = iota + + // [ociregistry.Writer] methods. + AccessWrite + + // [ociregistry.Deleter] methods. + AccessDelete + + // [ociregistry.Lister] methods. + AccessList +) + +// AccessChecker returns a wrapper for r that invokes check +// to check access before calling an underlying method. Only if check succeeds will +// the underlying method be called. // -// Requests for disallowed repositories will return ErrNameUnknown -// errors on read and ErrDenied on write. -func Select(r ociregistry.Interface, allow func(repoName string) bool) ociregistry.Interface { - return &selectRegistry{ - allow: allow, +// The check function is invoked with the name of the repository being +// accessed (or "*" for Repositories), and the kind of access required. +// For some methods (e.g. Mount), check might be invoked more than +// once for a given repository. +// +// When invoking the Repositories method, check is invoked for each repository in +// the iteration - the repository will be omitted if check returns an error. +func AccessChecker(r ociregistry.Interface, check func(repoName string, access AccessKind) error) ociregistry.Interface { + return &accessCheckerRegistry{ + check: check, r: r, } } -type selectRegistry struct { +type accessCheckerRegistry struct { // Embed Funcs rather than the interface directly so that // if new methods are added and selectRegistry isn't updated, // we fall back to returning an error rather than passing through the method. *ociregistry.Funcs - allow func(repoName string) bool + check func(repoName string, kind AccessKind) error r ociregistry.Interface } -func (r *selectRegistry) GetBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrNameUnknown +// Select returns a wrapper for r that provides only +// repositories for which allow returns true. +// +// Requests for disallowed repositories will return ErrNameUnknown +// errors on read and ErrDenied on write. +func Select(r ociregistry.Interface, allow func(repoName string) bool) ociregistry.Interface { + return AccessChecker(r, func(repoName string, access AccessKind) error { + if allow(repoName) { + return nil + } + if access == AccessWrite { + return ociregistry.ErrDenied + } + if access == AccessList && repoName == "*" { + return nil + } + return ociregistry.ErrNameUnknown + }) +} + +func (r *accessCheckerRegistry) GetBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) { + if err := r.check(repo, AccessRead); err != nil { + return nil, err } return r.r.GetBlob(ctx, repo, digest) } -func (r *selectRegistry) GetBlobRange(ctx context.Context, repo string, digest ociregistry.Digest, offset0, offset1 int64) (ociregistry.BlobReader, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) GetBlobRange(ctx context.Context, repo string, digest ociregistry.Digest, offset0, offset1 int64) (ociregistry.BlobReader, error) { + if err := r.check(repo, AccessRead); err != nil { + return nil, err } return r.r.GetBlobRange(ctx, repo, digest, offset0, offset1) } -func (r *selectRegistry) GetManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) GetManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) { + if err := r.check(repo, AccessRead); err != nil { + return nil, err } return r.r.GetManifest(ctx, repo, digest) } -func (r *selectRegistry) GetTag(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) GetTag(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { + if err := r.check(repo, AccessRead); err != nil { + return nil, err } return r.r.GetTag(ctx, repo, tagName) } -func (r *selectRegistry) ResolveBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { - if !r.allow(repo) { - return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) ResolveBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { + if err := r.check(repo, AccessRead); err != nil { + return ociregistry.Descriptor{}, err } return r.r.ResolveBlob(ctx, repo, digest) } -func (r *selectRegistry) ResolveManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { - if !r.allow(repo) { - return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) ResolveManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { + if err := r.check(repo, AccessRead); err != nil { + return ociregistry.Descriptor{}, err } return r.r.ResolveManifest(ctx, repo, digest) } -func (r *selectRegistry) ResolveTag(ctx context.Context, repo string, tagName string) (ociregistry.Descriptor, error) { - if !r.allow(repo) { - return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) ResolveTag(ctx context.Context, repo string, tagName string) (ociregistry.Descriptor, error) { + if err := r.check(repo, AccessRead); err != nil { + return ociregistry.Descriptor{}, err } return r.r.ResolveTag(ctx, repo, tagName) } -func (r *selectRegistry) PushBlob(ctx context.Context, repo string, desc ociregistry.Descriptor, rd io.Reader) (ociregistry.Descriptor, error) { - if !r.allow(repo) { - return ociregistry.Descriptor{}, ociregistry.ErrDenied +func (r *accessCheckerRegistry) PushBlob(ctx context.Context, repo string, desc ociregistry.Descriptor, rd io.Reader) (ociregistry.Descriptor, error) { + if err := r.check(repo, AccessWrite); err != nil { + return ociregistry.Descriptor{}, err } return r.r.PushBlob(ctx, repo, desc, rd) } -func (r *selectRegistry) PushBlobChunked(ctx context.Context, repo string, chunkSize int) (ociregistry.BlobWriter, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrDenied +func (r *accessCheckerRegistry) PushBlobChunked(ctx context.Context, repo string, chunkSize int) (ociregistry.BlobWriter, error) { + if err := r.check(repo, AccessWrite); err != nil { + return nil, err } return r.r.PushBlobChunked(ctx, repo, chunkSize) } -func (r *selectRegistry) PushBlobChunkedResume(ctx context.Context, repo, id string, offset int64, chunkSize int) (ociregistry.BlobWriter, error) { - if !r.allow(repo) { - return nil, ociregistry.ErrDenied +func (r *accessCheckerRegistry) PushBlobChunkedResume(ctx context.Context, repo, id string, offset int64, chunkSize int) (ociregistry.BlobWriter, error) { + if err := r.check(repo, AccessWrite); err != nil { + return nil, err } return r.r.PushBlobChunkedResume(ctx, repo, id, offset, chunkSize) } -func (r *selectRegistry) MountBlob(ctx context.Context, fromRepo, toRepo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { - if !r.allow(toRepo) { - return ociregistry.Descriptor{}, ociregistry.ErrDenied +func (r *accessCheckerRegistry) MountBlob(ctx context.Context, fromRepo, toRepo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) { + if err := r.check(fromRepo, AccessRead); err != nil { + return ociregistry.Descriptor{}, err } - if !r.allow(fromRepo) { - return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown + if err := r.check(toRepo, AccessWrite); err != nil { + return ociregistry.Descriptor{}, err } return r.r.MountBlob(ctx, fromRepo, toRepo, digest) } -func (r *selectRegistry) PushManifest(ctx context.Context, repo string, tag string, contents []byte, mediaType string) (ociregistry.Descriptor, error) { - if !r.allow(repo) { - return ociregistry.Descriptor{}, ociregistry.ErrDenied +func (r *accessCheckerRegistry) PushManifest(ctx context.Context, repo string, tag string, contents []byte, mediaType string) (ociregistry.Descriptor, error) { + if err := r.check(repo, AccessWrite); err != nil { + return ociregistry.Descriptor{}, err } return r.r.PushManifest(ctx, repo, tag, contents, mediaType) } -func (r *selectRegistry) DeleteBlob(ctx context.Context, repo string, digest ociregistry.Digest) error { - if !r.allow(repo) { - return ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) DeleteBlob(ctx context.Context, repo string, digest ociregistry.Digest) error { + if err := r.check(repo, AccessDelete); err != nil { + return err } return r.r.DeleteBlob(ctx, repo, digest) } -func (r *selectRegistry) DeleteManifest(ctx context.Context, repo string, digest ociregistry.Digest) error { - if !r.allow(repo) { - return ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) DeleteManifest(ctx context.Context, repo string, digest ociregistry.Digest) error { + if err := r.check(repo, AccessDelete); err != nil { + return err } return r.r.DeleteManifest(ctx, repo, digest) } -func (r *selectRegistry) DeleteTag(ctx context.Context, repo string, name string) error { - if !r.allow(repo) { - return ociregistry.ErrNameUnknown +func (r *accessCheckerRegistry) DeleteTag(ctx context.Context, repo string, name string) error { + if err := r.check(repo, AccessDelete); err != nil { + return err } return r.r.DeleteTag(ctx, repo, name) } -func (r *selectRegistry) Repositories(ctx context.Context, startAfter string) ociregistry.Seq[string] { +func (r *accessCheckerRegistry) Repositories(ctx context.Context, startAfter string) ociregistry.Seq[string] { + if err := r.check("*", AccessList); err != nil { + return ociregistry.ErrorSeq[string](err) + } return func(yield func(string, error) bool) { // TODO(go1.23): for name, err := range r.r.Repositories(ctx) r.r.Repositories(ctx, startAfter)(func(repo string, err error) bool { @@ -158,7 +204,7 @@ func (r *selectRegistry) Repositories(ctx context.Context, startAfter string) oc yield("", err) return false } - if !r.allow(repo) { + if r.check(repo, AccessRead) != nil { return true } return yield(repo, nil) @@ -166,16 +212,16 @@ func (r *selectRegistry) Repositories(ctx context.Context, startAfter string) oc } } -func (r *selectRegistry) Tags(ctx context.Context, repo, startAfter string) ociregistry.Seq[string] { - if !r.allow(repo) { - return ociregistry.ErrorSeq[string](ociregistry.ErrNameUnknown) +func (r *accessCheckerRegistry) Tags(ctx context.Context, repo, startAfter string) ociregistry.Seq[string] { + if err := r.check(repo, AccessList); err != nil { + return ociregistry.ErrorSeq[string](err) } return r.r.Tags(ctx, repo, startAfter) } -func (r *selectRegistry) Referrers(ctx context.Context, repo string, digest ociregistry.Digest, artifactType string) ociregistry.Seq[ociregistry.Descriptor] { - if !r.allow(repo) { - return ociregistry.ErrorSeq[ociregistry.Descriptor](ociregistry.ErrNameUnknown) +func (r *accessCheckerRegistry) Referrers(ctx context.Context, repo string, digest ociregistry.Digest, artifactType string) ociregistry.Seq[ociregistry.Descriptor] { + if err := r.check(repo, AccessList); err != nil { + return ociregistry.ErrorSeq[ociregistry.Descriptor](err) } return r.r.Referrers(ctx, repo, digest, artifactType) } diff --git a/ociregistry/ocifilter/select_test.go b/ociregistry/ocifilter/select_test.go new file mode 100644 index 00000000..117157fa --- /dev/null +++ b/ociregistry/ocifilter/select_test.go @@ -0,0 +1,210 @@ +// Copyright 2024 CUE Labs AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ocifilter + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/go-quicktest/qt" + "github.com/opencontainers/go-digest" + + "cuelabs.dev/go/oci/ociregistry" + "cuelabs.dev/go/oci/ociregistry/ocimem" +) + +func TestAccessCheckerErrorReturn(t *testing.T) { + ctx := context.Background() + testErr := errors.New("some error") + r1 := AccessChecker(ocimem.New(), func(repoName string, access AccessKind) error { + qt.Check(t, qt.Equals(repoName, "foo/bar")) + qt.Check(t, qt.Equals(access, AccessRead)) + return testErr + }) + _, err := r1.GetTag(ctx, "foo/bar", "t1") + qt.Assert(t, qt.ErrorIs(err, testErr)) +} + +func TestAccessCheckerAccessRequest(t *testing.T) { + assertAccess := func(wantAccess []accessCheck, do func(ctx context.Context, r ociregistry.Interface) error) { + testErr := errors.New("some error") + var gotAccess []accessCheck + r := AccessChecker(&ociregistry.Funcs{ + NewError: func(ctx context.Context, methodName, repo string) error { + return testErr + }, + }, func(repoName string, access AccessKind) error { + gotAccess = append(gotAccess, accessCheck{repoName, access}) + return nil + }) + err := do(context.Background(), r) + qt.Check(t, qt.ErrorIs(err, testErr)) + qt.Check(t, qt.DeepEquals(gotAccess, wantAccess)) + } + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.GetBlob(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + return err + }) + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + rd, err := r.GetBlobRange(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 100, 200) + if rd != nil { + rd.Close() + } + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + rd, err := r.GetManifest(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if rd != nil { + rd.Close() + } + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + rd, err := r.GetTag(ctx, "foo/read", "sometag") + if rd != nil { + rd.Close() + } + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.ResolveBlob(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.ResolveManifest(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.ResolveTag(ctx, "foo/read", "sometag") + return err + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessWrite}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.PushBlob(ctx, "foo/write", ociregistry.Descriptor{ + MediaType: "application/json", + Digest: "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + Size: 3, + }, strings.NewReader("foo")) + return err + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessWrite}, + }, func(ctx context.Context, r ociregistry.Interface) error { + w, err := r.PushBlobChunked(ctx, "foo/write", 0) + if err != nil { + return err + } + w.Close() + return nil + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessWrite}, + }, func(ctx context.Context, r ociregistry.Interface) error { + w, err := r.PushBlobChunkedResume(ctx, "foo/write", "/someid", 3, 0) + if err != nil { + return err + } + data := []byte("some data") + if _, err := w.Write(data); err != nil { + return err + } + _, err = w.Commit(digest.FromBytes(data)) + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessRead}, + {"foo/write", AccessWrite}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.MountBlob(ctx, "foo/read", "foo/write", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + return err + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessWrite}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := r.PushManifest(ctx, "foo/write", "sometag", []byte("something"), "application/json") + return err + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessDelete}, + }, func(ctx context.Context, r ociregistry.Interface) error { + return r.DeleteBlob(ctx, "foo/write", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessDelete}, + }, func(ctx context.Context, r ociregistry.Interface) error { + return r.DeleteManifest(ctx, "foo/write", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + }) + + assertAccess([]accessCheck{ + {"foo/write", AccessDelete}, + }, func(ctx context.Context, r ociregistry.Interface) error { + return r.DeleteTag(ctx, "foo/write", "sometag") + }) + + assertAccess([]accessCheck{ + {"*", AccessList}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := ociregistry.All(r.Repositories(ctx, "")) + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessList}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := ociregistry.All(r.Tags(ctx, "foo/read", "")) + return err + }) + + assertAccess([]accessCheck{ + {"foo/read", AccessList}, + }, func(ctx context.Context, r ociregistry.Interface) error { + _, err := ociregistry.All(r.Referrers(ctx, "foo/read", "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "")) + return err + }) +} + +type accessCheck struct { + Repo string + Check AccessKind +} diff --git a/ociregistry/ocifilter/sub_test.go b/ociregistry/ocifilter/sub_test.go index 1059bcf7..cdbc35c1 100644 --- a/ociregistry/ocifilter/sub_test.go +++ b/ociregistry/ocifilter/sub_test.go @@ -90,15 +90,15 @@ func TestSub(t *testing.T) { }, }, }) - r1 := ocitest.NewRegistry(t, Sub(r.R, "foo")) - desc, err := r1.R.ResolveTag(ctx, "bar", "t1") + r1 := Sub(r.R, "foo") + desc, err := r1.ResolveTag(ctx, "bar", "t1") qt.Assert(t, qt.IsNil(err)) - m := getManifest(t, r1.R, "bar", desc.Digest) - b1Content := getBlob(t, r1.R, "bar", m.Layers[0].Digest) + m := getManifest(t, r1, "bar", desc.Digest) + b1Content := getBlob(t, r1, "bar", m.Layers[0].Digest) qt.Assert(t, qt.Equals(string(b1Content), "hello")) - repos, err := ociregistry.All(r1.R.Repositories(ctx, "")) + repos, err := ociregistry.All(r1.Repositories(ctx, "")) qt.Assert(t, qt.IsNil(err)) sort.Strings(repos) qt.Assert(t, qt.DeepEquals(repos, []string{"bar"}))