diff --git a/snapshotter/demux/cache/cache.go b/snapshotter/demux/cache/cache.go index 1250144ff..9d1c3cb4e 100644 --- a/snapshotter/demux/cache/cache.go +++ b/snapshotter/demux/cache/cache.go @@ -16,6 +16,7 @@ package cache import ( "context" + "github.com/containerd/containerd/snapshots" "github.com/firecracker-microvm/firecracker-containerd/snapshotter/demux/proxy" ) @@ -28,6 +29,9 @@ type Cache interface { // fetch function if the snapshotter is not currently cached. Get(ctx context.Context, key string, fetch SnapshotterProvider) (*proxy.RemoteSnapshotter, error) + // WalkAll applies the provided function across all cached snapshotters. + WalkAll(ctx context.Context, fn snapshots.WalkFunc, filters ...string) error + // Closes the snapshotter and removes it from the cache. Evict(key string) error diff --git a/snapshotter/demux/cache/snapshotter_cache.go b/snapshotter/demux/cache/snapshotter_cache.go index 761cbf871..7915a92c9 100644 --- a/snapshotter/demux/cache/snapshotter_cache.go +++ b/snapshotter/demux/cache/snapshotter_cache.go @@ -18,6 +18,7 @@ import ( "fmt" "sync" + "github.com/containerd/containerd/snapshots" "github.com/firecracker-microvm/firecracker-containerd/snapshotter/demux/proxy" "github.com/hashicorp/go-multierror" ) @@ -56,6 +57,21 @@ func (cache *SnapshotterCache) Get(ctx context.Context, key string, fetch Snapsh return snapshotter, nil } +// WalkAll applies the provided function to all cached snapshotters. +func (cache *SnapshotterCache) WalkAll(ctx context.Context, fn snapshots.WalkFunc, filters ...string) error { + cache.mutex.RLock() + defer cache.mutex.RUnlock() + + var allErr error + + for namespace, snapshotter := range cache.snapshotters { + if err := snapshotter.Walk(ctx, fn, filters...); err != nil { + allErr = multierror.Append(allErr, fmt.Errorf("failed to walk function on snapshotter[%s]: %w", namespace, err)) + } + } + return allErr +} + // Evict removes a cached snapshotter for a given key. func (cache *SnapshotterCache) Evict(key string) error { cache.mutex.RLock() diff --git a/snapshotter/demux/cache/snapshotter_cache_test.go b/snapshotter/demux/cache/snapshotter_cache_test.go index 46405eb2f..bb84f4529 100644 --- a/snapshotter/demux/cache/snapshotter_cache_test.go +++ b/snapshotter/demux/cache/snapshotter_cache_test.go @@ -17,6 +17,7 @@ import ( "context" "testing" + "github.com/containerd/containerd/snapshots" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" @@ -55,13 +56,53 @@ func getCachedSnapshotter(uut *SnapshotterCache) error { return nil } -func getSnapshotterPropogatesErrors(uut *SnapshotterCache) error { +func getSnapshotterPropagatesErrors(uut *SnapshotterCache) error { if _, err := uut.Get(context.Background(), "SnapshotterKey", getSnapshotterErrorFunction); err == nil { return errors.New("Get function did not propagate errors from snapshotter generator function") } return nil } +func successfulWalk(ctx context.Context, info snapshots.Info) error { + return nil +} + +func applyWalkFunctionOnEmptyCache(uut *SnapshotterCache) error { + if err := uut.WalkAll(context.Background(), successfulWalk); err != nil { + return errors.New("WalkAll on empty cache incorrectly resulted in error") + } + return nil +} + +func applyWalkFunctionToAllCachedSnapshotters(uut *SnapshotterCache) error { + if _, err := uut.Get(context.Background(), "Snapshotter-A", getSnapshotterOkFunction); err != nil { + return errors.Wrap(err, "Adding snapshotter A to empty cache incorrectly resulted in error") + } + if _, err := uut.Get(context.Background(), "Snapshotter-B", getSnapshotterOkFunction); err != nil { + return errors.Wrap(err, "Adding snapshotter B to empty cache incorrectly resulted in error") + } + if err := uut.WalkAll(context.Background(), successfulWalk); err != nil { + return errors.New("WalkAll on populated cache incorrectly resulted in error") + } + return nil +} + +func applyWalkFunctionPropagatesErrors(uut *SnapshotterCache) error { + if _, err := uut.Get(context.Background(), "Snapshotter-A", getFailingSnapshotterOkFunction); err != nil { + return errors.Wrap(err, "Adding snapshotter A to empty cache incorrectly resulted in error") + } + // The failing snapshotter mock will fail all Walk calls before applying + // the snapshots.WalkFunc, but for the purposes of this test that is fine. + // In which case, any function will do. + walkFunc := func(ctx context.Context, info snapshots.Info) error { + return nil + } + if err := uut.WalkAll(context.Background(), walkFunc); err == nil { + return errors.New("WalkAll did not propagate errors from walk function") + } + return nil +} + func evictSnapshotterFromEmptyCache(uut *SnapshotterCache) error { if err := uut.Evict("SnapshotterKey"); err == nil { return errors.New("Evict function did not return error on call on empty cache") @@ -80,7 +121,7 @@ func evictSnapshotterFromCache(uut *SnapshotterCache) error { return nil } -func evictSnapshotterFromCachePropogatesCloseError(uut *SnapshotterCache) error { +func evictSnapshotterFromCachePropagatesCloseError(uut *SnapshotterCache) error { if _, err := uut.Get(context.Background(), "SnapshotterKey", getFailingSnapshotterOkFunction); err != nil { return errors.Wrap(err, "Adding snapshotter to empty cache incorrectly resulted in error") } @@ -156,14 +197,36 @@ func TestGetSnapshotterFromCache(t *testing.T) { }{ {"AddSnapshotterToCache", getSnapshotterFromEmptyCache}, {"GetCachedSnapshotter", getCachedSnapshotter}, - {"PropogateFetchSnapshotterErrors", getSnapshotterPropogatesErrors}, + {"PropogateFetchSnapshotterErrors", getSnapshotterPropagatesErrors}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + uut := NewSnapshotterCache() + if err := test.run(uut); err != nil { + t.Fatalf("%s: %s", test.name, err.Error()) + } + }) + } +} + +func TestWalkAllFunctionOnCache(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(*SnapshotterCache) error + }{ + {"ApplyWalkFunctionOnEmptyCache", applyWalkFunctionOnEmptyCache}, + {"ApplyWalkFunctionToAllCachedSnapshotters", applyWalkFunctionToAllCachedSnapshotters}, + {"ApplyWalkFunctionPropogatesErrors", applyWalkFunctionPropagatesErrors}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { uut := NewSnapshotterCache() if err := test.run(uut); err != nil { - t.Fatal(test.name + ": " + err.Error()) + t.Fatalf("%s: %s", test.name, err.Error()) } }) } @@ -178,14 +241,14 @@ func TestEvictSnapshotterFromCache(t *testing.T) { }{ {"EvictSnapshotterFromEmptyCache", evictSnapshotterFromEmptyCache}, {"EvictSnapshotterFromCache", evictSnapshotterFromCache}, - {"PropogateEvictSnapshotterCloseErrors", evictSnapshotterFromCachePropogatesCloseError}, + {"PropogateEvictSnapshotterCloseErrors", evictSnapshotterFromCachePropagatesCloseError}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { uut := NewSnapshotterCache() if err := test.run(uut); err != nil { - t.Fatal(test.name + ": " + err.Error()) + t.Fatalf("%s: %s", test.name, err.Error()) } }) } @@ -207,7 +270,7 @@ func TestCloseCache(t *testing.T) { t.Run(test.name, func(t *testing.T) { uut := NewSnapshotterCache() if err := test.run(uut); err != nil { - t.Fatal(test.name + ": " + err.Error()) + t.Fatalf("%s: %s", test.name, err.Error()) } }) } @@ -228,7 +291,7 @@ func TestListSnapshotters(t *testing.T) { t.Run(test.name, func(t *testing.T) { uut := NewSnapshotterCache() if err := test.run(uut); err != nil { - t.Fatal(test.name + ": " + err.Error()) + t.Fatalf("%s: %s", test.name, err.Error()) } }) } diff --git a/snapshotter/demux/snapshotter.go b/snapshotter/demux/snapshotter.go index 345f6fc48..b1e55d693 100644 --- a/snapshotter/demux/snapshotter.go +++ b/snapshotter/demux/snapshotter.go @@ -243,7 +243,8 @@ func (s *Snapshotter) Walk(ctx context.Context, fn snapshots.WalkFunc, filters . contextLogger := log.G(ctx).WithField("function", "Walk") namespace, err := getNamespaceFromContext(ctx, contextLogger) if err != nil { - return err + contextLogger.Debug("no namespace found, proxying walk function to all cached snapshotters") + return s.cache.WalkAll(ctx, fn, filters...) } logger := contextLogger.WithField("namespace", namespace) diff --git a/snapshotter/demux/snapshotter_test.go b/snapshotter/demux/snapshotter_test.go index 70d20bd81..2202e1ac7 100644 --- a/snapshotter/demux/snapshotter_test.go +++ b/snapshotter/demux/snapshotter_test.go @@ -73,6 +73,27 @@ func TestReturnErrorWhenCalledWithoutNamespacedContext(t *testing.T) { {"View", func() error { _, err := uut.View(ctx, "layerKey", ""); return err }}, {"Commit", func() error { return uut.Commit(ctx, "layer1", "layerKey") }}, {"Remove", func() error { return uut.Remove(ctx, "layerKey") }}, + } + + for _, test := range tests { + if err := test.run(); err == nil { + t.Fatalf("%s call did not return error", test.name) + } + } +} + +func TestNoErrorWhenCalledWithoutNamespacedContext(t *testing.T) { + t.Parallel() + + cache := cache.NewSnapshotterCache() + ctx := logtest.WithT(context.Background(), t) + + uut := NewSnapshotter(cache, fetchOkSnapshotter) + + tests := []struct { + name string + run func() error + }{ {"Walk", func() error { var callback = func(c context.Context, i snapshots.Info) error { return nil } return uut.Walk(ctx, callback) @@ -80,8 +101,8 @@ func TestReturnErrorWhenCalledWithoutNamespacedContext(t *testing.T) { } for _, test := range tests { - if err := test.run(); err == nil { - t.Fatal(test.name + " call did not return error") + if err := test.run(); err != nil { + t.Fatalf("%s call returned error on no namespace execution", test.name) } } } @@ -116,7 +137,7 @@ func TestReturnErrorWhenSnapshotterNotFound(t *testing.T) { for _, test := range tests { if err := test.run(); err == nil { - t.Fatal(test.name + " call did not return error") + t.Fatalf("%s call did not return error", test.name) } } } @@ -153,7 +174,7 @@ func TestReturnErrorAfterProxyFunctionFailure(t *testing.T) { for _, test := range tests { t.Run(test.name+"ProxyFailure", func(t *testing.T) { if err := test.run(); err == nil { - t.Fatal(test.name + " call did not return error") + t.Fatalf("%s call did not return error", test.name) } }) } @@ -191,7 +212,7 @@ func TestNoErrorIsReturnedOnSuccessfulProxyExecution(t *testing.T) { for _, test := range tests { t.Run(test.name+"SuccessfulProxyCall", func(t *testing.T) { if err := test.run(); err != nil { - t.Fatal(test.name + " call incorrectly returned an error") + t.Fatalf("%s call incorrectly returned an error", test.name) } }) }