From d792cbe82e1b96a77b6ae839c3e576a5ca077ff5 Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Tue, 21 Jan 2025 07:33:42 -0700 Subject: [PATCH] Improve the tests for listing active sessions (#51246) Use a more standard table-driven test setup, as the previous approach put the description of the test at the very end. Also adds a bit more flexibility to allow for additional more complicated tests to be added, and adds one additional test case to verify that explicit deny rules work. --- lib/auth/auth_with_roles_test.go | 266 ++++++++++++++++++------------- 1 file changed, 157 insertions(+), 109 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 923e74f46f4e7..9fa89ab1b1301 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -6610,142 +6610,190 @@ func TestLocalServiceRolesHavePermissionsForUploaderService(t *testing.T) { } } -type getActiveSessionsTestCase struct { - name string - tracker types.SessionTracker - role types.Role - hasAccess bool -} - func TestGetActiveSessionTrackers(t *testing.T) { t.Parallel() - testCases := []getActiveSessionsTestCase{func() getActiveSessionsTestCase { - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: "1", - Kind: string(types.SSHSessionKind), - }) - require.NoError(t, err) - - role, err := types.NewRole("foo", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{{ - Resources: []string{types.KindSessionTracker}, - Verbs: []string{types.VerbList, types.VerbRead}, - }}, - }, - }) - require.NoError(t, err) - - return getActiveSessionsTestCase{"with access simple", tracker, role, true} - }(), func() getActiveSessionsTestCase { - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: "1", - Kind: string(types.SSHSessionKind), - }) - require.NoError(t, err) - - role, err := types.NewRole("foo", types.RoleSpecV6{}) - require.NoError(t, err) + type activeSessionsTestCase struct { + name string + makeRole func() (types.Role, error) + makeTracker func(testUser types.User) (types.SessionTracker, error) + extraSetup func(*testing.T, *TestTLSServer) - return getActiveSessionsTestCase{"with no access rule", tracker, role, false} - }(), func() getActiveSessionsTestCase { - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: "1", - Kind: string(types.KubernetesSessionKind), - }) - require.NoError(t, err) + checkSessionTrackers require.ValueAssertionFunc + } - role, err := types.NewRole("foo", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{{ - Resources: []string{types.KindSessionTracker}, - Verbs: []string{types.VerbList, types.VerbRead}, - Where: "equals(session_tracker.session_id, \"1\")", - }}, + for _, tc := range []activeSessionsTestCase{ + { + name: "simple-access", + makeRole: func() (types.Role, error) { + return types.NewRole("foo", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{{ + Resources: []string{types.KindSessionTracker}, + Verbs: []string{types.VerbList, types.VerbRead}, + }}, + }, + }) }, - }) - require.NoError(t, err) - - return getActiveSessionsTestCase{"access with match expression", tracker, role, true} - }(), func() getActiveSessionsTestCase { - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: "2", - Kind: string(types.KubernetesSessionKind), - }) - require.NoError(t, err) - - role, err := types.NewRole("foo", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{{ - Resources: []string{types.KindSessionTracker}, - Verbs: []string{types.VerbList, types.VerbRead}, - Where: "equals(session_tracker.session_id, \"1\")", - }}, + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "1", + Kind: string(types.SSHSessionKind), + }) }, - }) - require.NoError(t, err) - - return getActiveSessionsTestCase{"no access with match expression", tracker, role, false} - }(), func() getActiveSessionsTestCase { - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: "1", - Kind: string(types.SSHSessionKind), - }) - require.NoError(t, err) - - role, err := types.NewRoleWithVersion("dev", types.V3, types.RoleSpecV6{ - Allow: types.RoleConditions{ - AppLabels: types.Labels{"*": []string{"*"}}, - DatabaseLabels: types.Labels{"*": []string{"*"}}, - KubernetesLabels: types.Labels{"*": []string{"*"}}, - KubernetesResources: []types.KubernetesResource{ - {Kind: types.KindKubePod, Name: "*", Namespace: "*", Verbs: []string{"*"}}, - }, - NodeLabels: types.Labels{"*": []string{"*"}}, - NodeLabelsExpression: `contains(user.spec.traits["cluster_ids"], labels["cluster_id"]) || contains(user.spec.traits["sub"], labels["owner"])`, - Logins: []string{"{{external.sub}}"}, - WindowsDesktopLabels: types.Labels{"cluster_id": []string{"{{external.cluster_ids}}"}}, - WindowsDesktopLogins: []string{"{{external.sub}}", "{{external.windows_logins}}"}, + checkSessionTrackers: require.NotEmpty, + }, + { + name: "no-access-rule", + makeRole: func() (types.Role, error) { + return types.NewRole("foo", types.RoleSpecV6{}) }, - Deny: types.RoleConditions{ - Rules: []types.Rule{ - { - Resources: []string{types.KindDatabaseServer, types.KindAppServer, types.KindSession, types.KindSSHSession, types.KindKubeService, types.KindSessionTracker}, - Verbs: []string{"list", "read"}, + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "1", + Kind: string(types.SSHSessionKind), + }) + }, + checkSessionTrackers: require.Empty, + }, + { + name: "access-with-match-expression", + makeRole: func() (types.Role, error) { + return types.NewRole("foo", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{{ + Resources: []string{types.KindSessionTracker}, + Verbs: []string{types.VerbList, types.VerbRead}, + Where: "equals(session_tracker.session_id, \"1\")", + }}, }, - }, + }) }, - }) - require.NoError(t, err) - - return getActiveSessionsTestCase{"filter bug v3 role", tracker, role, false} - }(), - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "1", + Kind: string(types.SSHSessionKind), + }) + }, + checkSessionTrackers: require.NotEmpty, + }, + { + name: "no-access-with-match-expression", + makeRole: func() (types.Role, error) { + return types.NewRole("foo", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{{ + Resources: []string{types.KindSessionTracker}, + Verbs: []string{types.VerbList, types.VerbRead}, + Where: "equals(session_tracker.session_id, \"1\")", + }}, + }, + }) + }, + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "2", + Kind: string(types.KubernetesSessionKind), + }) + }, + checkSessionTrackers: require.Empty, + }, + { + name: "filter-bug-v3-role", + makeRole: func() (types.Role, error) { + return types.NewRoleWithVersion("dev", types.V3, types.RoleSpecV6{ + Allow: types.RoleConditions{ + AppLabels: types.Labels{"*": []string{"*"}}, + DatabaseLabels: types.Labels{"*": []string{"*"}}, + KubernetesLabels: types.Labels{"*": []string{"*"}}, + KubernetesResources: []types.KubernetesResource{ + {Kind: types.KindKubePod, Name: "*", Namespace: "*", Verbs: []string{"*"}}, + }, + NodeLabels: types.Labels{"*": []string{"*"}}, + NodeLabelsExpression: `contains(user.spec.traits["cluster_ids"], labels["cluster_id"]) || contains(user.spec.traits["sub"], labels["owner"])`, + Logins: []string{"{{external.sub}}"}, + WindowsDesktopLabels: types.Labels{"cluster_id": []string{"{{external.cluster_ids}}"}}, + WindowsDesktopLogins: []string{"{{external.sub}}", "{{external.windows_logins}}"}, + }, + Deny: types.RoleConditions{ + Rules: []types.Rule{ + { + Resources: []string{types.KindDatabaseServer, types.KindAppServer, types.KindSession, types.KindSSHSession, types.KindKubeService, types.KindSessionTracker}, + Verbs: []string{"list", "read"}, + }, + }, + }, + }) + }, + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "1", + Kind: string(types.SSHSessionKind), + }) + }, + checkSessionTrackers: require.Empty, + }, + { + name: "explicit-deny-wins", // so long as the user doesn't have join permissions + makeRole: func() (types.Role, error) { + return types.NewRole("foo", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{{ + Resources: []string{types.KindSessionTracker}, + Verbs: []string{types.VerbList, types.VerbRead}, + }}, + }, + Deny: types.RoleConditions{ + Rules: []types.Rule{{ + Resources: []string{types.KindSessionTracker}, + Verbs: []string{types.VerbList, types.VerbRead}, + }}, + }, + }) + }, + makeTracker: func(testUser types.User) (types.SessionTracker, error) { + return types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: "1", + Kind: string(types.SSHSessionKind), + }) + }, + checkSessionTrackers: require.Empty, + }, + } { + t.Run(tc.name, func(t *testing.T) { ctx := context.Background() srv := newTestTLSServer(t) - _, err := srv.Auth().CreateRole(ctx, testCase.role) + + role, err := tc.makeRole() require.NoError(t, err) - _, err = srv.Auth().CreateSessionTracker(ctx, testCase.tracker) + _, err = srv.Auth().CreateRole(ctx, role) require.NoError(t, err) user, err := types.NewUser(uuid.NewString()) require.NoError(t, err) - user.AddRole(testCase.role.GetName()) + user.AddRole(role.GetName()) user, err = srv.Auth().UpsertUser(ctx, user) require.NoError(t, err) + if tc.extraSetup != nil { + tc.extraSetup(t, srv) + } + + tracker, err := tc.makeTracker(user) + require.NoError(t, err) + + _, err = srv.Auth().CreateSessionTracker(ctx, tracker) + require.NoError(t, err) + clt, err := srv.NewClient(TestUser(user.GetName())) require.NoError(t, err) found, err := clt.GetActiveSessionTrackers(ctx) require.NoError(t, err) - require.Equal(t, testCase.hasAccess, len(found) != 0) + + tc.checkSessionTrackers(t, found) }) } }