From 9051675df0d9f384507579657570153b579aa77b Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Thu, 20 Feb 2025 15:11:47 -0500 Subject: [PATCH] Reduce resource consumption when generating Kubernetes certificates (#52109) (#52146) Closes https://github.com/gravitational/teleport/issues/52073. The requested Kubernetes cluster is now cross referenced with the KubeServers in the unified resource cache. This results in a reduction in CPU, memory, and cert generation latency. This also cleans up some of the helper functions in lib/kube/utils that were no longer needed, and suboptimal. The client side changes here shouldn't have any impact, as the server is performing the same check, and returning the equivalent error the client side code used to. This will also cut the time of `tctl auth sign` in half as both the client and server were performing the same expensive CheckKubeCluster operation. --- lib/auth/auth.go | 23 +++++- lib/auth/auth_test.go | 107 +++++++++++++++++++++++++- lib/kube/utils/utils.go | 59 -------------- lib/kube/utils/utils_test.go | 82 -------------------- tool/tctl/common/auth_command.go | 20 +---- tool/tctl/common/auth_command_test.go | 104 ------------------------- tool/tsh/common/tsh_test.go | 23 ++++++ 7 files changed, 151 insertions(+), 267 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 48e37ad292198..260b62c8cd3e3 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -102,7 +102,6 @@ import ( "github.com/gravitational/teleport/lib/githubactions" "github.com/gravitational/teleport/lib/gitlab" "github.com/gravitational/teleport/lib/inventory" - kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/kubernetestoken" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/loginrule" @@ -3311,9 +3310,29 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. // If the certificate is targeting a trusted Teleport cluster, it is the // responsibility of the cluster to ensure its existence. if req.routeToCluster == clusterName && req.kubernetesCluster != "" { - if err := kubeutils.CheckKubeCluster(a.closeCtx, a, req.kubernetesCluster); err != nil { + found, _, err := a.UnifiedResourceCache.IterateUnifiedResources(a.closeCtx, func(rwl types.ResourceWithLabels) (bool, error) { + if rwl.GetKind() != types.KindKubeServer { + return false, nil + } + + ks, ok := rwl.(types.KubeServer) + if !ok { + return false, nil + } + + return ks.GetCluster().GetName() == req.kubernetesCluster, nil + }, &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindKubeServer}, + SortBy: types.SortBy{Field: services.SortByName}, + Limit: 1, + }) + if err != nil { return nil, trace.Wrap(err) } + + if len(found) == 0 { + return nil, trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", req.kubernetesCluster) + } } // See which database names and users this user is allowed to use. diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 2bce5e38cfcf2..269ecb2de5cf5 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -149,12 +149,25 @@ func newTestPack( } p.a.SetLockWatcher(lockWatcher) - // set cluster name - err = p.a.SetClusterName(p.clusterName) + urc, err := services.NewUnifiedResourceCache(ctx, services.UnifiedResourceCacheConfig{ + Clock: p.a.clock, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentAuth, + Client: p.a, + }, + ResourceGetter: p.a, + }) if err != nil { return p, trace.Wrap(err) } + p.a.SetUnifiedResourcesCache(urc) + + // set cluster name + if err := p.a.SetClusterName(p.clusterName); err != nil { + return p, trace.Wrap(err) + } + // set static tokens staticTokens, err := types.NewStaticTokens(types.StaticTokensSpecV2{ StaticTokens: []types.ProvisionTokenV1{}, @@ -3004,6 +3017,96 @@ func TestGenerateUserCertWithHardwareKeySupport(t *testing.T) { } } +func TestGenerateKubernetesUserCert(t *testing.T) { + ctx := context.Background() + p, err := newTestPack(ctx, t.TempDir()) + require.NoError(t, err) + + user, _, err := CreateUserAndRole(p.a, "test-user", []string{}, nil) + require.NoError(t, err) + + rc, err := types.NewRemoteCluster("leaf") + require.NoError(t, err) + _, err = p.a.CreateRemoteCluster(ctx, rc) + require.NoError(t, err) + + kubeCluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: "kube-cluster"}, types.KubernetesClusterSpecV3{}) + require.NoError(t, err) + kubeServer, err := types.NewKubernetesServerV3FromCluster(kubeCluster, "foo", "1") + require.NoError(t, err) + _, err = p.a.UpsertKubernetesServer(ctx, kubeServer) + require.NoError(t, err) + + // Wait for cache propagation of the kubernetes resources before proceeding with the tests. + require.EventuallyWithT(t, func(t *assert.CollectT) { + found, _, err := p.a.UnifiedResourceCache.IterateUnifiedResources(ctx, func(rwl types.ResourceWithLabels) (bool, error) { + if rwl.GetKind() != types.KindKubeServer { + return false, nil + } + + ks, ok := rwl.(types.KubeServer) + if !ok { + return false, nil + } + + return ks.GetCluster().GetName() == kubeCluster.GetName(), nil + }, &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindKubeServer}, + SortBy: types.SortBy{Field: services.SortByName}, + Limit: 1, + }) + + assert.NoError(t, err) + assert.Len(t, found, 1) + }, 10*time.Second, 100*time.Millisecond) + + accessInfo := services.AccessInfoFromUserState(user) + accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a) + require.NoError(t, err) + + _, sshPubKey, _, tlsPubKey := newSSHAndTLSKeyPairs(t) + + for _, tt := range []struct { + name string + teleportCluster string + kubernetesCluster string + assertErr require.ErrorAssertionFunc + }{ + { + name: "leaf clusters not validated", + teleportCluster: "leaf", + kubernetesCluster: "foo", + assertErr: require.NoError, + }, + { + name: "kubernetes cluster not registered", + teleportCluster: p.clusterName.GetClusterName(), + kubernetesCluster: "foo", + assertErr: require.Error, + }, + { + name: "kubernetes cluster registered", + teleportCluster: p.clusterName.GetClusterName(), + kubernetesCluster: kubeCluster.GetName(), + assertErr: require.NoError, + }, + } { + t.Run(tt.name, func(t *testing.T) { + certReq := certRequest{ + user: user, + checker: accessChecker, + sshPublicKey: sshPubKey, + tlsPublicKey: tlsPubKey, + routeToCluster: tt.teleportCluster, + kubernetesCluster: tt.kubernetesCluster, + } + + _, err = p.a.generateUserCert(ctx, certReq) + tt.assertErr(t, err) + }) + } +} + func TestNewWebSession(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/lib/kube/utils/utils.go b/lib/kube/utils/utils.go index f52aad3166301..f300ff12c2c20 100644 --- a/lib/kube/utils/utils.go +++ b/lib/kube/utils/utils.go @@ -22,7 +22,6 @@ import ( "context" "encoding/hex" "errors" - "slices" "strings" "github.com/gravitational/trace" @@ -148,48 +147,6 @@ func EncodeClusterName(clusterName string) string { return "k" + hex.EncodeToString([]byte(clusterName)) } -// KubeServicesPresence fetches a list of registered kubernetes servers. -// It's a subset of services.Presence. -type KubeServicesPresence interface { - // GetKubernetesServers returns a list of registered kubernetes servers. - GetKubernetesServers(context.Context) ([]types.KubeServer, error) -} - -// KubeClusterNames returns a sorted list of unique kubernetes cluster -// names registered in p. -// -// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters -func KubeClusterNames(ctx context.Context, p KubeServicesPresence) ([]string, error) { - kss, err := p.GetKubernetesServers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - return extractAndSortKubeClusterNames(kss), nil -} - -func extractAndSortKubeClusterNames(kubeServers []types.KubeServer) []string { - kubeClusters := extractAndSortKubeClusters(kubeServers) - kubeClusterNames := make([]string, len(kubeClusters)) - for i := range kubeClusters { - kubeClusterNames[i] = kubeClusters[i].GetName() - } - - return kubeClusterNames -} - -// KubeClusters returns a sorted list of unique kubernetes clusters -// registered in p. -// -// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters -func KubeClusters(ctx context.Context, p KubeServicesPresence) ([]types.KubeCluster, error) { - kubeServers, err := p.GetKubernetesServers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - return extractAndSortKubeClusters(kubeServers), nil -} - // ListKubeClustersWithFilters returns a sorted list of unique kubernetes clusters // registered in p. func ListKubeClustersWithFilters(ctx context.Context, p client.GetResourcesClient, req proto.ListResourcesRequest) ([]types.KubeCluster, error) { @@ -245,19 +202,3 @@ func GetKubeAgentVersion(ctx context.Context, pinger Pinger, clusterFeatures pro return strings.TrimPrefix(agentVersion, "v"), nil } - -// CheckKubeCluster validates kubeClusterName is registered with this Teleport cluster. -func CheckKubeCluster(ctx context.Context, p KubeServicesPresence, kubeClusterName string) error { - if kubeClusterName == "" { - return trace.BadParameter("kube cluster name should not be empty.") - } - kubeClusterNames, err := KubeClusterNames(ctx, p) - if err != nil { - return trace.Wrap(err, "failed to get list of available Kubernetes clusters.") - } - if !slices.Contains(kubeClusterNames, kubeClusterName) { - return trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", kubeClusterName) - } - - return nil -} diff --git a/lib/kube/utils/utils_test.go b/lib/kube/utils/utils_test.go index fb90ebd0a3194..b963c4277c0c4 100644 --- a/lib/kube/utils/utils_test.go +++ b/lib/kube/utils/utils_test.go @@ -26,74 +26,9 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/automaticupgrades" ) -func TestCheckKubeCluster(t *testing.T) { - t.Parallel() - ctx := context.Background() - - kubeServers := []types.KubeServer{ - kubeServer(t, "k8s-1", "server1", "uuuid"), - kubeServer(t, "k8s-2", "server1", "uuuid"), - kubeServer(t, "k8s-3", "server1", "uuuid"), - kubeServer(t, "k8s-4", "server1", "uuuid"), - } - - tests := []struct { - desc string - services []types.KubeServer - kubeCluster string - assertErr require.ErrorAssertionFunc - }{ - { - desc: "valid cluster name", - services: kubeServers, - kubeCluster: "k8s-4", - assertErr: require.NoError, - }, - { - desc: "invalid cluster name", - services: kubeServers, - kubeCluster: "k8s-5", - assertErr: require.Error, - }, - { - desc: "no registered clusters", - services: []types.KubeServer{}, - kubeCluster: "k8s-1", - assertErr: require.Error, - }, - { - desc: "empty cluster provided", - services: kubeServers, - kubeCluster: "", - assertErr: require.Error, - }, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - err := CheckKubeCluster(ctx, mockKubeServicesPresence(tt.services), tt.kubeCluster) - tt.assertErr(t, err) - }) - } -} - -type mockKubeServicesPresence []types.KubeServer - -func (p mockKubeServicesPresence) GetKubernetesServers(context.Context) ([]types.KubeServer, error) { - return p, nil -} - -func kubeServer(t *testing.T, kubeCluster, hostname, hostID string) types.KubeServer { - cluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: kubeCluster}, types.KubernetesClusterSpecV3{}) - require.NoError(t, err) - server, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, hostID) - require.NoError(t, err) - return server -} - func TestGetAgentVersion(t *testing.T) { t.Parallel() @@ -162,20 +97,3 @@ type pinger struct { func (p *pinger) Ping(ctx context.Context) (proto.PingResponse, error) { return p.pingFn(ctx) } - -func TestExtractAndSortKubeClusterNames(t *testing.T) { - t.Parallel() - - server1 := kubeServer(t, "watermelon", "server1", "uuuid") - - server2 := kubeServer(t, "watermelon", "server1", "uuuid") - - server3 := kubeServer(t, "banana", "server2", "uuuid2") - - server4 := kubeServer(t, "apple", "server2", "uuuid2") - - server5 := kubeServer(t, "pear", "server2", "uuuid2") - - names := extractAndSortKubeClusterNames(types.KubeServers{server1, server2, server3, server4, server5}) - require.Equal(t, []string{"apple", "banana", "pear", "watermelon"}, names) -} diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index b11b9643c69be..cddc6c44d0069 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -48,7 +48,6 @@ import ( "github.com/gravitational/teleport/lib/client/identityfile" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" - kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -317,7 +316,6 @@ func (a *AuthCommand) GenerateKeys(ctx context.Context, clusterAPI authCommandCl // certificateSigner is an interface for the methods used by GenerateAndSignKeys // to sign certificates using the Auth Server. type certificateSigner interface { - kubeutils.KubeServicesPresence GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error) GenerateWindowsDesktopCert(context.Context, *proto.WindowsDesktopCertRequest) (*proto.WindowsDesktopCertResponse, error) @@ -931,7 +929,7 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI certifica } keyRing.ClusterName = a.leafCluster - if err := a.checkKubeCluster(ctx, clusterAPI); err != nil { + if err := a.checkKubeCluster(); err != nil { return trace.Wrap(err) } @@ -1092,7 +1090,7 @@ func (a *AuthCommand) checkLeafCluster(clusterAPI certificateSigner) error { return trace.BadParameter("couldn't find leaf cluster named %q", a.leafCluster) } -func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certificateSigner) error { +func (a *AuthCommand) checkKubeCluster() error { if a.kubeCluster == "" { return nil } @@ -1105,20 +1103,6 @@ func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certifica return nil } - localCluster, err := clusterAPI.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - if localCluster.GetClusterName() != a.leafCluster { - // Skip validation on remote clusters, since we don't know their - // registered kube clusters. - return nil - } - - if err := kubeutils.CheckKubeCluster(ctx, clusterAPI, a.kubeCluster); err != nil { - return trace.Wrap(err) - } - return nil } diff --git a/tool/tctl/common/auth_command_test.go b/tool/tctl/common/auth_command_test.go index 5840591b0b45c..0e3c958df0feb 100644 --- a/tool/tctl/common/auth_command_test.go +++ b/tool/tctl/common/auth_command_test.go @@ -470,110 +470,6 @@ func (c *mockClient) GenerateCertAuthorityCRL(context.Context, types.CertAuthTyp return c.crl, nil } -func TestCheckKubeCluster(t *testing.T) { - const teleportCluster = "local-teleport" - clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ - ClusterName: teleportCluster, - }) - require.NoError(t, err) - client := &mockClient{ - clusterName: clusterName, - } - tests := []struct { - desc string - kubeCluster string - leafCluster string - outputFormat identityfile.Format - registeredClusters []*types.KubernetesClusterV3 - want string - assertErr require.ErrorAssertionFunc - }{ - { - desc: "non-k8s output format", - outputFormat: identityfile.FormatFile, - assertErr: require.NoError, - }, - { - desc: "local cluster, valid kube cluster", - kubeCluster: "foo", - leafCluster: teleportCluster, - registeredClusters: []*types.KubernetesClusterV3{{Metadata: types.Metadata{Name: "foo"}}}, - outputFormat: identityfile.FormatKubernetes, - want: "foo", - assertErr: require.NoError, - }, - { - desc: "local cluster, empty kube cluster", - kubeCluster: "", - leafCluster: teleportCluster, - registeredClusters: []*types.KubernetesClusterV3{{Metadata: types.Metadata{Name: "foo"}}}, - outputFormat: identityfile.FormatKubernetes, - assertErr: require.NoError, - }, - { - desc: "local cluster, empty kube cluster, no registered kube clusters", - kubeCluster: "", - leafCluster: teleportCluster, - registeredClusters: []*types.KubernetesClusterV3{}, - outputFormat: identityfile.FormatKubernetes, - want: "", - assertErr: require.NoError, - }, - { - desc: "local cluster, invalid kube cluster", - kubeCluster: "bar", - leafCluster: teleportCluster, - registeredClusters: []*types.KubernetesClusterV3{{Metadata: types.Metadata{Name: "foo"}}}, - outputFormat: identityfile.FormatKubernetes, - assertErr: require.Error, - }, - { - desc: "remote cluster, empty kube cluster", - kubeCluster: "", - leafCluster: "remote-teleport", - registeredClusters: []*types.KubernetesClusterV3{{Metadata: types.Metadata{Name: "foo"}}}, - outputFormat: identityfile.FormatKubernetes, - want: "", - assertErr: require.NoError, - }, - { - desc: "remote cluster, non-empty kube cluster", - kubeCluster: "bar", - leafCluster: "remote-teleport", - registeredClusters: []*types.KubernetesClusterV3{{Metadata: types.Metadata{Name: "foo"}}}, - outputFormat: identityfile.FormatKubernetes, - want: "bar", - assertErr: require.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - client.kubeServers = []types.KubeServer{} - for _, kube := range tt.registeredClusters { - client.kubeServers = append(client.kubeServers, &types.KubernetesServerV3{ - Metadata: types.Metadata{ - Name: kube.GetName(), - }, - Spec: types.KubernetesServerSpecV3{ - Hostname: "host", - Cluster: kube, - }, - }) - } - a := &AuthCommand{ - kubeCluster: tt.kubeCluster, - leafCluster: tt.leafCluster, - outputFormat: tt.outputFormat, - } - err := a.checkKubeCluster(context.Background(), client) - tt.assertErr(t, err) - if err == nil { - require.Equal(t, tt.want, a.kubeCluster) - } - }) - } -} - // TestGenerateDatabaseKeys verifies cert/key pair generation for databases. func TestGenerateDatabaseKeys(t *testing.T) { clusterName, err := services.NewClusterNameWithRandomID( diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 010fb8135dc89..873aacca28822 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -2627,6 +2627,29 @@ func TestKubeCredentialsLock(t *testing.T) { _, err = authServer.UpsertKubernetesServer(context.Background(), kubeServer) require.NoError(t, err) + require.EventuallyWithT(t, func(t *assert.CollectT) { + found, _, err := authServer.UnifiedResourceCache.IterateUnifiedResources(ctx, func(rwl types.ResourceWithLabels) (bool, error) { + if rwl.GetKind() != types.KindKubeServer { + return false, nil + } + + ks, ok := rwl.(types.KubeServer) + if !ok { + return false, nil + } + + return ks.GetCluster().GetName() == kubeCluster.GetName(), nil + }, &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindKubeServer}, + SortBy: types.SortBy{Field: services.SortByName}, + Limit: 1, + }) + + assert.NoError(t, err) + assert.Len(t, found, 1) + + }, 10*time.Second, 100*time.Millisecond) + var ssoCalls atomic.Int32 mockSSOLogin := mockSSOLogin(authServer, alice) mockSSOLoginWithCountCalls := func(ctx context.Context, connectorID string, keyRing *client.KeyRing, protocol string) (*authclient.SSHLoginResponse, error) {