From 88fdb279685f6b9d77f39a2615703849fefc7068 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 20 Feb 2025 11:03:48 -0500 Subject: [PATCH] migrate users to in-memory cache --- api/types/user.go | 11 +- lib/cache/cache.go | 63 ------------ lib/cache/cache_test.go | 35 +------ lib/cache/collection.go | 9 ++ lib/cache/legacy_collections.go | 123 ----------------------- lib/cache/users.go | 173 ++++++++++++++++++++++++++++++++ lib/cache/users_test.go | 55 ++++++++++ lib/services/local/events.go | 5 +- 8 files changed, 249 insertions(+), 225 deletions(-) create mode 100644 lib/cache/users.go create mode 100644 lib/cache/users_test.go diff --git a/api/types/user.go b/api/types/user.go index f87fe1958606f..e6360a0471b36 100644 --- a/api/types/user.go +++ b/api/types/user.go @@ -156,6 +156,8 @@ type User interface { SetWeakestDevice(MFADeviceKind) // GetWeakestDevice gets the MFA state for the user. GetWeakestDevice() MFADeviceKind + // Clone creats a copy of the user. + Clone() User } // NewUser creates new empty user @@ -271,14 +273,19 @@ func (u *UserV2) SetName(e string) { u.Metadata.Name = e } +func (u *UserV2) Clone() User { + return utils.CloneProtoMsg(u) +} + // WithoutSecrets returns an instance of resource without secrets. func (u *UserV2) WithoutSecrets() Resource { if u.Spec.LocalAuth == nil { return u } - u2 := *u + + u2 := utils.CloneProtoMsg(u) u2.Spec.LocalAuth = nil - return &u2 + return u2 } // GetTraits gets the trait map for this user used to populate role variables. diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 37a64b2e10074..ed14d1b993332 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -46,7 +46,6 @@ import ( notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1" provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" - userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/internalutils/stream" apitracing "github.com/gravitational/teleport/api/observability/tracing" @@ -1084,7 +1083,6 @@ func New(config Config) (*Cache, error) { clusterConfigCache: clusterConfigCache, autoUpdateCache: autoUpdateCache, provisionerCache: local.NewProvisioningService(config.Backend), - usersCache: identityService, accessCache: local.NewAccessService(config.Backend), dynamicAccessCache: local.NewDynamicAccessService(config.Backend), presenceCache: local.NewPresenceService(config.Backend), @@ -2322,67 +2320,6 @@ func (c *Cache) ListRemoteClusters(ctx context.Context, pageSize int, nextToken return remoteClusters, token, trace.Wrap(err) } -// GetUser is a part of auth.Cache implementation. -func (c *Cache) GetUser(ctx context.Context, name string, withSecrets bool) (types.User, error) { - _, span := c.Tracer.Start(ctx, "cache/GetUser") - defer span.End() - - if withSecrets { // cache never tracks user secrets - return c.Config.Users.GetUser(ctx, name, withSecrets) - } - rg, err := readCollectionCache(c, c.legacyCacheCollections.users) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - user, err := rg.reader.GetUser(ctx, name, withSecrets) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if user, err := c.Config.Users.GetUser(ctx, name, withSecrets); err == nil { - return user, nil - } - } - return user, trace.Wrap(err) -} - -// GetUsers is a part of auth.Cache implementation -func (c *Cache) GetUsers(ctx context.Context, withSecrets bool) ([]types.User, error) { - _, span := c.Tracer.Start(ctx, "cache/GetUsers") - defer span.End() - - if withSecrets { // cache never tracks user secrets - return c.Users.GetUsers(ctx, withSecrets) - } - rg, err := readCollectionCache(c, c.legacyCacheCollections.users) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetUsers(ctx, withSecrets) -} - -// ListUsers returns a page of users. -func (c *Cache) ListUsers(ctx context.Context, req *userspb.ListUsersRequest) (*userspb.ListUsersResponse, error) { - _, span := c.Tracer.Start(ctx, "cache/ListUsers") - defer span.End() - - if req.WithSecrets { // cache never tracks user secrets - rsp, err := c.Users.ListUsers(ctx, req) - return rsp, trace.Wrap(err) - } - rg, err := readCollectionCache(c, c.legacyCacheCollections.users) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - rsp, err := rg.reader.ListUsers(ctx, req) - return rsp, trace.Wrap(err) -} - // GetTunnelConnections is a part of auth.Cache implementation func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) { _, span := c.Tracer.Start(context.TODO(), "cache/GetTunnelConnections") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index fa37fed634e3e..cd3f93110deae 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -670,8 +670,6 @@ func TestNodeCAFiltering(t *testing.T) { Events: p.cache, Trust: p.cache.trustCache, ClusterConfig: p.cache.clusterConfigCache, - Provisioner: p.cache.provisionerCache, - Users: p.cache.usersCache, Access: p.cache.accessCache, DynamicAccess: p.cache.dynamicAccessCache, Presence: p.cache.presenceCache, @@ -1600,37 +1598,6 @@ func TestClusterName(t *testing.T) { require.Empty(t, cmp.Diff(outName, clusterName, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) } -// TestUsers tests caching of users -func TestUsers(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForProxy) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[types.User]{ - newResource: func(name string) (types.User, error) { - return types.NewUser("bob") - }, - create: func(ctx context.Context, user types.User) error { - _, err := p.usersS.UpsertUser(ctx, user) - return err - }, - list: func(ctx context.Context) ([]types.User, error) { - return p.usersS.GetUsers(ctx, false) - }, - cacheList: func(ctx context.Context) ([]types.User, error) { - return p.cache.GetUsers(ctx, false) - }, - update: func(ctx context.Context, user types.User) error { - _, err := p.usersS.UpdateUser(ctx, user) - return err - }, - deleteAll: func(ctx context.Context) error { - return p.usersS.DeleteAllUsers(ctx) - }, - }) -} - // TestRoles tests caching of roles func TestRoles(t *testing.T) { t.Parallel() @@ -3580,7 +3547,7 @@ func TestPartialHealth(t *testing.T) { meta := user.GetMetadata() meta.Labels = map[string]string{"origin": "cache"} user.SetMetadata(meta) - _, err = p.cache.usersCache.UpsertUser(ctx, user) + err = p.cache.collections.users.onUpdate(user) require.NoError(t, err) // the label on the returned user proves that it came from the cache diff --git a/lib/cache/collection.go b/lib/cache/collection.go index 34e4b5a9dabb7..f8d12ef45ae44 100644 --- a/lib/cache/collection.go +++ b/lib/cache/collection.go @@ -154,6 +154,7 @@ type collections struct { staticTokens *collection[types.StaticTokens, *singletonStore[types.StaticTokens], *staticTokensUpstream] certAuthorities *collection[types.CertAuthority, *resourceStore[types.CertAuthority], *caUpstream] + users *collection[types.User, *resourceStore[types.User], *userUpstream] } func setupCollections(c Config, watches []types.WatchKind) (*collections, error) { @@ -181,6 +182,14 @@ func setupCollections(c Config, watches []types.WatchKind) (*collections, error) out.certAuthorities = collect out.byKind[resourceKind] = out.certAuthorities + case types.KindUser: + collect, err := newUserCollection(c.Users, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.users = collect + out.byKind[resourceKind] = out.users } } diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index 16f0374c41ad8..ade4b3578695c 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -154,7 +154,6 @@ type legacyCollections struct { snowflakeSessions collectionReader[snowflakeSessionGetter] tokens collectionReader[tokenGetter] uiConfigs collectionReader[uiConfigGetter] - users collectionReader[userGetter] userGroups collectionReader[userGroupGetter] userLoginStates collectionReader[services.UserLoginStatesGetter] webSessions collectionReader[webSessionGetter] @@ -259,15 +258,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.uiConfigs - case types.KindUser: - if c.Users == nil { - return nil, trace.BadParameter("missing parameter Users") - } - collections.users = &genericCollection[types.User, userGetter, userExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.users case types.KindRole: if c.Access == nil { return nil, trace.BadParameter("missing parameter Access") @@ -1041,119 +1031,6 @@ type nodeGetter interface { var _ executor[types.Server, nodeGetter] = nodeExecutor{} -type certAuthorityExecutor struct { - // extracted from watch.Filter, to avoid rebuilding on every event - filter types.CertAuthorityFilter -} - -// delete implements executor[types.CertAuthority] -func (certAuthorityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - err := cache.trustCache.DeleteCertAuthority(ctx, types.CertAuthID{ - Type: types.CertAuthType(resource.GetSubKind()), - DomainName: resource.GetName(), - }) - return trace.Wrap(err) -} - -// deleteAll implements executor[types.CertAuthority] -func (certAuthorityExecutor) deleteAll(ctx context.Context, cache *Cache) error { - for _, caType := range types.CertAuthTypes { - if err := cache.trustCache.DeleteAllCertAuthorities(caType); err != nil { - return trace.Wrap(err) - } - } - return nil -} - -// getAll implements executor[types.CertAuthority] -func (e certAuthorityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.CertAuthority, error) { - var authorities []types.CertAuthority - for _, caType := range types.CertAuthTypes { - cas, err := cache.Trust.GetCertAuthorities(ctx, caType, loadSecrets) - // if caType was added in this major version we might get a BadParameter - // error if we're connecting to an older upstream that doesn't know about it - if err != nil { - if !(types.IsUnsupportedAuthorityErr(err) && caType.NewlyAdded()) { - return nil, trace.Wrap(err) - } - continue - } - - // this can be removed once we get the ability to fetch CAs with a filter, - // but it should be harmless, and it could be kept as additional safety - if !e.filter.IsEmpty() { - filtered := cas[:0] - for _, ca := range cas { - if e.filter.Match(ca) { - filtered = append(filtered, ca) - } - } - cas = filtered - } - - authorities = append(authorities, cas...) - } - - return authorities, nil -} - -// upsert implements executor[types.CertAuthority] -func (e certAuthorityExecutor) upsert(ctx context.Context, cache *Cache, value types.CertAuthority) error { - if !e.filter.Match(value) { - return nil - } - - return cache.trustCache.UpsertCertAuthority(ctx, value) -} - -func (certAuthorityExecutor) isSingleton() bool { return false } - -func (certAuthorityExecutor) getReader(cache *Cache, cacheOK bool) services.AuthorityGetter { - if cacheOK { - return cache.trustCache - } - return cache.Config.Trust -} - -var _ executor[types.CertAuthority, services.AuthorityGetter] = certAuthorityExecutor{} - -type staticTokensExecutor struct{} - -func (staticTokensExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.StaticTokens, error) { - token, err := cache.ClusterConfig.GetStaticTokens() - if err != nil { - return nil, trace.Wrap(err) - } - return []types.StaticTokens{token}, nil -} - -func (staticTokensExecutor) upsert(ctx context.Context, cache *Cache, resource types.StaticTokens) error { - return cache.clusterConfigCache.SetStaticTokens(resource) -} - -func (staticTokensExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.clusterConfigCache.DeleteStaticTokens() -} - -func (staticTokensExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.clusterConfigCache.DeleteStaticTokens() -} - -func (staticTokensExecutor) isSingleton() bool { return true } - -func (staticTokensExecutor) getReader(cache *Cache, cacheOK bool) staticTokensGetter { - if cacheOK { - return cache.clusterConfigCache - } - return cache.Config.ClusterConfig -} - -type staticTokensGetter interface { - GetStaticTokens() (types.StaticTokens, error) -} - -var _ executor[types.StaticTokens, staticTokensGetter] = staticTokensExecutor{} - type provisionTokenExecutor struct{} func (provisionTokenExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.ProvisionToken, error) { diff --git a/lib/cache/users.go b/lib/cache/users.go new file mode 100644 index 0000000000000..991d051c9e6d3 --- /dev/null +++ b/lib/cache/users.go @@ -0,0 +1,173 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "strings" + + userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/trace" +) + +func newUserCollection(u services.UsersService, w types.WatchKind) (*collection[types.User, *resourceStore[types.User], *userUpstream], error) { + if u == nil { + return nil, trace.BadParameter("missing parameter UsersService") + } + + return &collection[types.User, *resourceStore[types.User], *userUpstream]{ + store: newResourceStore(map[string]func(types.User) string{ + "name": func(u types.User) string { + return u.GetName() + }, + }), + upstream: &userUpstream{UsersService: u}, + watch: w, + }, nil +} + +type userUpstream struct { + services.UsersService +} + +func (c userUpstream) getAll(ctx context.Context, loadSecrets bool) ([]types.User, error) { + return c.UsersService.GetUsers(ctx, loadSecrets) +} + +// GetUser is a part of auth.Cache implementation. +func (c *Cache) GetUser(ctx context.Context, name string, withSecrets bool) (types.User, error) { + _, span := c.Tracer.Start(ctx, "cache/GetUser") + defer span.End() + + if withSecrets { // cache never tracks user secrets + return c.Config.Users.GetUser(ctx, name, withSecrets) + } + + user, err := readCachedResource( + ctx, + c, + c.collections.users, + func(ctx context.Context, store *resourceStore[types.User]) (types.User, error) { + u, err := store.get("name", name) + if err != nil { + // fallback is sane because method is never used + // in construction of derivative caches. + if trace.IsNotFound(err) { + if user, err := c.Config.Users.GetUser(ctx, name, withSecrets); err == nil { + return user, nil + } + } + return nil, trace.Wrap(err) + } + + if withSecrets { + return u.Clone(), nil + } + + return u.WithoutSecrets().(types.User), nil + }, + func(ctx context.Context, upstream *userUpstream) (types.User, error) { + user, err := upstream.GetUser(ctx, name, withSecrets) + return user, trace.Wrap(err) + }) + + return user, trace.Wrap(err) +} + +// GetUsers is a part of auth.Cache implementation +func (c *Cache) GetUsers(ctx context.Context, withSecrets bool) ([]types.User, error) { + _, span := c.Tracer.Start(ctx, "cache/GetUsers") + defer span.End() + + if withSecrets { // cache never tracks user secrets + return c.Users.GetUsers(ctx, withSecrets) + } + + users, err := readCachedResource( + ctx, + c, + c.collections.users, + func(_ context.Context, store *resourceStore[types.User]) ([]types.User, error) { + var users []types.User + for u := range store.iterate("name", "", "") { + if withSecrets { + users = append(users, u.Clone()) + } else { + users = append(users, u.WithoutSecrets().(types.User)) + } + } + + return users, nil + }, + func(ctx context.Context, upstream *userUpstream) ([]types.User, error) { + users, err := upstream.GetUsers(ctx, withSecrets) + return users, trace.Wrap(err) + }) + + return users, trace.Wrap(err) +} + +// ListUsers returns a page of users. +func (c *Cache) ListUsers(ctx context.Context, req *userspb.ListUsersRequest) (*userspb.ListUsersResponse, error) { + _, span := c.Tracer.Start(ctx, "cache/ListUsers") + defer span.End() + + if req.WithSecrets { // cache never tracks user secrets + rsp, err := c.Users.ListUsers(ctx, req) + return rsp, trace.Wrap(err) + } + + users, err := readCachedResource( + ctx, + c, + c.collections.users, + func(_ context.Context, store *resourceStore[types.User]) (*userspb.ListUsersResponse, error) { + var resp userspb.ListUsersResponse + for u := range store.iterate("name", req.PageToken, "") { + uv2, ok := u.(*types.UserV2) + if !ok { + continue + } + + if req.Filter != nil && !req.Filter.Match(uv2) { + continue + } + + if len(resp.Users) == int(req.PageSize) { + key := backend.RangeEnd(backend.ExactKey(u.GetName())).String() + resp.NextPageToken = strings.Trim(key, string(backend.Separator)) + break + } + + if req.WithSecrets { + resp.Users = append(resp.Users, u.Clone().(*types.UserV2)) + } else { + resp.Users = append(resp.Users, u.WithoutSecrets().(*types.UserV2)) + } + } + return &resp, nil + }, + func(ctx context.Context, upstream *userUpstream) (*userspb.ListUsersResponse, error) { + resp, err := upstream.ListUsers(ctx, req) + return resp, trace.Wrap(err) + }) + + return users, trace.Wrap(err) +} diff --git a/lib/cache/users_test.go b/lib/cache/users_test.go new file mode 100644 index 0000000000000..70644a222b0e0 --- /dev/null +++ b/lib/cache/users_test.go @@ -0,0 +1,55 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/teleport/api/types" +) + +// TestUsers tests caching of users +func TestUsers(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForProxy) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.User]{ + newResource: func(name string) (types.User, error) { + return types.NewUser("bob") + }, + create: func(ctx context.Context, user types.User) error { + _, err := p.usersS.UpsertUser(ctx, user) + return err + }, + list: func(ctx context.Context) ([]types.User, error) { + return p.usersS.GetUsers(ctx, false) + }, + cacheList: func(ctx context.Context) ([]types.User, error) { + return p.cache.GetUsers(ctx, false) + }, + update: func(ctx context.Context, user types.User) error { + _, err := p.usersS.UpdateUser(ctx, user) + return err + }, + deleteAll: func(ctx context.Context) error { + return p.usersS.DeleteAllUsers(ctx) + }, + }) +} diff --git a/lib/services/local/events.go b/lib/services/local/events.go index c7ec41f30cf95..acdb90b3eda8f 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -1044,12 +1044,11 @@ func (p *userParser) parse(event backend.Event) (types.Resource, error) { return nil, trace.NotFound("failed parsing %v", event.Item.Key.String()) } - return &types.ResourceHeader{ + return &types.UserV2{ Kind: types.KindUser, Version: types.V2, Metadata: types.Metadata{ - Name: strings.TrimPrefix(name, backend.SeparatorString), - Namespace: apidefaults.Namespace, + Name: strings.TrimPrefix(name, backend.SeparatorString), }, }, nil case types.OpPut: