Skip to content

Commit

Permalink
migrate users to in-memory cache
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Feb 20, 2025
1 parent aede314 commit 88fdb27
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 225 deletions.
11 changes: 9 additions & 2 deletions api/types/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ type User interface {
SetWeakestDevice(MFADeviceKind)
// GetWeakestDevice gets the MFA state for the user.
GetWeakestDevice() MFADeviceKind
// Clone creats a copy of the user.
Clone() User
}

// NewUser creates new empty user
Expand Down Expand Up @@ -271,14 +273,19 @@ func (u *UserV2) SetName(e string) {
u.Metadata.Name = e
}

func (u *UserV2) Clone() User {
return utils.CloneProtoMsg(u)
}

// WithoutSecrets returns an instance of resource without secrets.
func (u *UserV2) WithoutSecrets() Resource {
if u.Spec.LocalAuth == nil {
return u
}
u2 := *u

u2 := utils.CloneProtoMsg(u)
u2.Spec.LocalAuth = nil
return &u2
return u2
}

// GetTraits gets the trait map for this user used to populate role variables.
Expand Down
63 changes: 0 additions & 63 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ import (
notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1"
provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1"
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
"github.com/gravitational/teleport/api/internalutils/stream"
apitracing "github.com/gravitational/teleport/api/observability/tracing"
Expand Down Expand Up @@ -1084,7 +1083,6 @@ func New(config Config) (*Cache, error) {
clusterConfigCache: clusterConfigCache,
autoUpdateCache: autoUpdateCache,
provisionerCache: local.NewProvisioningService(config.Backend),
usersCache: identityService,
accessCache: local.NewAccessService(config.Backend),
dynamicAccessCache: local.NewDynamicAccessService(config.Backend),
presenceCache: local.NewPresenceService(config.Backend),
Expand Down Expand Up @@ -2322,67 +2320,6 @@ func (c *Cache) ListRemoteClusters(ctx context.Context, pageSize int, nextToken
return remoteClusters, token, trace.Wrap(err)
}

// GetUser is a part of auth.Cache implementation.
func (c *Cache) GetUser(ctx context.Context, name string, withSecrets bool) (types.User, error) {
_, span := c.Tracer.Start(ctx, "cache/GetUser")
defer span.End()

if withSecrets { // cache never tracks user secrets
return c.Config.Users.GetUser(ctx, name, withSecrets)
}
rg, err := readCollectionCache(c, c.legacyCacheCollections.users)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()

user, err := rg.reader.GetUser(ctx, name, withSecrets)
if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if user, err := c.Config.Users.GetUser(ctx, name, withSecrets); err == nil {
return user, nil
}
}
return user, trace.Wrap(err)
}

// GetUsers is a part of auth.Cache implementation
func (c *Cache) GetUsers(ctx context.Context, withSecrets bool) ([]types.User, error) {
_, span := c.Tracer.Start(ctx, "cache/GetUsers")
defer span.End()

if withSecrets { // cache never tracks user secrets
return c.Users.GetUsers(ctx, withSecrets)
}
rg, err := readCollectionCache(c, c.legacyCacheCollections.users)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
return rg.reader.GetUsers(ctx, withSecrets)
}

// ListUsers returns a page of users.
func (c *Cache) ListUsers(ctx context.Context, req *userspb.ListUsersRequest) (*userspb.ListUsersResponse, error) {
_, span := c.Tracer.Start(ctx, "cache/ListUsers")
defer span.End()

if req.WithSecrets { // cache never tracks user secrets
rsp, err := c.Users.ListUsers(ctx, req)
return rsp, trace.Wrap(err)
}
rg, err := readCollectionCache(c, c.legacyCacheCollections.users)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
rsp, err := rg.reader.ListUsers(ctx, req)
return rsp, trace.Wrap(err)
}

// GetTunnelConnections is a part of auth.Cache implementation
func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) {
_, span := c.Tracer.Start(context.TODO(), "cache/GetTunnelConnections")
Expand Down
35 changes: 1 addition & 34 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,6 @@ func TestNodeCAFiltering(t *testing.T) {
Events: p.cache,
Trust: p.cache.trustCache,
ClusterConfig: p.cache.clusterConfigCache,
Provisioner: p.cache.provisionerCache,
Users: p.cache.usersCache,
Access: p.cache.accessCache,
DynamicAccess: p.cache.dynamicAccessCache,
Presence: p.cache.presenceCache,
Expand Down Expand Up @@ -1600,37 +1598,6 @@ func TestClusterName(t *testing.T) {
require.Empty(t, cmp.Diff(outName, clusterName, cmpopts.IgnoreFields(types.Metadata{}, "Revision")))
}

// TestUsers tests caching of users
func TestUsers(t *testing.T) {
t.Parallel()

p := newTestPack(t, ForProxy)
t.Cleanup(p.Close)

testResources(t, p, testFuncs[types.User]{
newResource: func(name string) (types.User, error) {
return types.NewUser("bob")
},
create: func(ctx context.Context, user types.User) error {
_, err := p.usersS.UpsertUser(ctx, user)
return err
},
list: func(ctx context.Context) ([]types.User, error) {
return p.usersS.GetUsers(ctx, false)
},
cacheList: func(ctx context.Context) ([]types.User, error) {
return p.cache.GetUsers(ctx, false)
},
update: func(ctx context.Context, user types.User) error {
_, err := p.usersS.UpdateUser(ctx, user)
return err
},
deleteAll: func(ctx context.Context) error {
return p.usersS.DeleteAllUsers(ctx)
},
})
}

// TestRoles tests caching of roles
func TestRoles(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -3580,7 +3547,7 @@ func TestPartialHealth(t *testing.T) {
meta := user.GetMetadata()
meta.Labels = map[string]string{"origin": "cache"}
user.SetMetadata(meta)
_, err = p.cache.usersCache.UpsertUser(ctx, user)
err = p.cache.collections.users.onUpdate(user)
require.NoError(t, err)

// the label on the returned user proves that it came from the cache
Expand Down
9 changes: 9 additions & 0 deletions lib/cache/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ type collections struct {

staticTokens *collection[types.StaticTokens, *singletonStore[types.StaticTokens], *staticTokensUpstream]
certAuthorities *collection[types.CertAuthority, *resourceStore[types.CertAuthority], *caUpstream]
users *collection[types.User, *resourceStore[types.User], *userUpstream]
}

func setupCollections(c Config, watches []types.WatchKind) (*collections, error) {
Expand Down Expand Up @@ -181,6 +182,14 @@ func setupCollections(c Config, watches []types.WatchKind) (*collections, error)

out.certAuthorities = collect
out.byKind[resourceKind] = out.certAuthorities
case types.KindUser:
collect, err := newUserCollection(c.Users, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.users = collect
out.byKind[resourceKind] = out.users
}
}

Expand Down
123 changes: 0 additions & 123 deletions lib/cache/legacy_collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ type legacyCollections struct {
snowflakeSessions collectionReader[snowflakeSessionGetter]
tokens collectionReader[tokenGetter]
uiConfigs collectionReader[uiConfigGetter]
users collectionReader[userGetter]
userGroups collectionReader[userGroupGetter]
userLoginStates collectionReader[services.UserLoginStatesGetter]
webSessions collectionReader[webSessionGetter]
Expand Down Expand Up @@ -259,15 +258,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect
watch: watch,
}
collections.byKind[resourceKind] = collections.uiConfigs
case types.KindUser:
if c.Users == nil {
return nil, trace.BadParameter("missing parameter Users")
}
collections.users = &genericCollection[types.User, userGetter, userExecutor]{
cache: c,
watch: watch,
}
collections.byKind[resourceKind] = collections.users
case types.KindRole:
if c.Access == nil {
return nil, trace.BadParameter("missing parameter Access")
Expand Down Expand Up @@ -1041,119 +1031,6 @@ type nodeGetter interface {

var _ executor[types.Server, nodeGetter] = nodeExecutor{}

type certAuthorityExecutor struct {
// extracted from watch.Filter, to avoid rebuilding on every event
filter types.CertAuthorityFilter
}

// delete implements executor[types.CertAuthority]
func (certAuthorityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
err := cache.trustCache.DeleteCertAuthority(ctx, types.CertAuthID{
Type: types.CertAuthType(resource.GetSubKind()),
DomainName: resource.GetName(),
})
return trace.Wrap(err)
}

// deleteAll implements executor[types.CertAuthority]
func (certAuthorityExecutor) deleteAll(ctx context.Context, cache *Cache) error {
for _, caType := range types.CertAuthTypes {
if err := cache.trustCache.DeleteAllCertAuthorities(caType); err != nil {
return trace.Wrap(err)
}
}
return nil
}

// getAll implements executor[types.CertAuthority]
func (e certAuthorityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.CertAuthority, error) {
var authorities []types.CertAuthority
for _, caType := range types.CertAuthTypes {
cas, err := cache.Trust.GetCertAuthorities(ctx, caType, loadSecrets)
// if caType was added in this major version we might get a BadParameter
// error if we're connecting to an older upstream that doesn't know about it
if err != nil {
if !(types.IsUnsupportedAuthorityErr(err) && caType.NewlyAdded()) {
return nil, trace.Wrap(err)
}
continue
}

// this can be removed once we get the ability to fetch CAs with a filter,
// but it should be harmless, and it could be kept as additional safety
if !e.filter.IsEmpty() {
filtered := cas[:0]
for _, ca := range cas {
if e.filter.Match(ca) {
filtered = append(filtered, ca)
}
}
cas = filtered
}

authorities = append(authorities, cas...)
}

return authorities, nil
}

// upsert implements executor[types.CertAuthority]
func (e certAuthorityExecutor) upsert(ctx context.Context, cache *Cache, value types.CertAuthority) error {
if !e.filter.Match(value) {
return nil
}

return cache.trustCache.UpsertCertAuthority(ctx, value)
}

func (certAuthorityExecutor) isSingleton() bool { return false }

func (certAuthorityExecutor) getReader(cache *Cache, cacheOK bool) services.AuthorityGetter {
if cacheOK {
return cache.trustCache
}
return cache.Config.Trust
}

var _ executor[types.CertAuthority, services.AuthorityGetter] = certAuthorityExecutor{}

type staticTokensExecutor struct{}

func (staticTokensExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.StaticTokens, error) {
token, err := cache.ClusterConfig.GetStaticTokens()
if err != nil {
return nil, trace.Wrap(err)
}
return []types.StaticTokens{token}, nil
}

func (staticTokensExecutor) upsert(ctx context.Context, cache *Cache, resource types.StaticTokens) error {
return cache.clusterConfigCache.SetStaticTokens(resource)
}

func (staticTokensExecutor) deleteAll(ctx context.Context, cache *Cache) error {
return cache.clusterConfigCache.DeleteStaticTokens()
}

func (staticTokensExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
return cache.clusterConfigCache.DeleteStaticTokens()
}

func (staticTokensExecutor) isSingleton() bool { return true }

func (staticTokensExecutor) getReader(cache *Cache, cacheOK bool) staticTokensGetter {
if cacheOK {
return cache.clusterConfigCache
}
return cache.Config.ClusterConfig
}

type staticTokensGetter interface {
GetStaticTokens() (types.StaticTokens, error)
}

var _ executor[types.StaticTokens, staticTokensGetter] = staticTokensExecutor{}

type provisionTokenExecutor struct{}

func (provisionTokenExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.ProvisionToken, error) {
Expand Down
Loading

0 comments on commit 88fdb27

Please sign in to comment.