diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 0cb29695f968b..ff6cdd143e7e1 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -235,7 +235,7 @@ jobs:
- name: Check if Terraform resources are up to date
# We have to add the current directory as a safe directory or else git commands will not work as expected.
# The protoc-gen-terraform version must match the version in integrations/terraform/Makefile
- run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225 && make terraform-resources-up-to-date
+ run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2 && make terraform-resources-up-to-date
lint-rfd:
name: Lint (RFD)
diff --git a/.golangci.yml b/.golangci.yml
index 98859bad6c7d9..ecc5e7c8e253f 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -115,14 +115,6 @@ linters-settings:
desc: 'use "crypto" or "x/crypto" instead'
# Prevent importing any additional logging libraries.
logging:
- files:
- # Integrations are still allowed to use logrus becuase they haven't
- # been converted to slog yet. Once they use slog, remove this exception.
- - '!**/integrations/**'
- # The log package still contains the logrus formatter consumed by the integrations.
- # Remove this exception when said formatter is deleted.
- - '!**/lib/utils/log/**'
- - '!**/lib/utils/cli.go'
deny:
- pkg: github.com/sirupsen/logrus
desc: 'use "log/slog" instead'
diff --git a/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go b/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go
index 1849d7e902173..fa758941db455 100644
--- a/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go
+++ b/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go
@@ -121,20 +121,209 @@ func (x *WorkloadIdentity) GetSpec() *WorkloadIdentitySpec {
return nil
}
+// The attribute casted to a string must be equal to the value.
+type WorkloadIdentityConditionEq struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // The value to compare the attribute against.
+ Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WorkloadIdentityConditionEq) Reset() {
+ *x = WorkloadIdentityConditionEq{}
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WorkloadIdentityConditionEq) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WorkloadIdentityConditionEq) ProtoMessage() {}
+
+func (x *WorkloadIdentityConditionEq) ProtoReflect() protoreflect.Message {
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WorkloadIdentityConditionEq.ProtoReflect.Descriptor instead.
+func (*WorkloadIdentityConditionEq) Descriptor() ([]byte, []int) {
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *WorkloadIdentityConditionEq) GetValue() string {
+ if x != nil {
+ return x.Value
+ }
+ return ""
+}
+
+// The attribute casted to a string must not be equal to the value.
+type WorkloadIdentityConditionNotEq struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // The value to compare the attribute against.
+ Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WorkloadIdentityConditionNotEq) Reset() {
+ *x = WorkloadIdentityConditionNotEq{}
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WorkloadIdentityConditionNotEq) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WorkloadIdentityConditionNotEq) ProtoMessage() {}
+
+func (x *WorkloadIdentityConditionNotEq) ProtoReflect() protoreflect.Message {
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WorkloadIdentityConditionNotEq.ProtoReflect.Descriptor instead.
+func (*WorkloadIdentityConditionNotEq) Descriptor() ([]byte, []int) {
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{2}
+}
+
+func (x *WorkloadIdentityConditionNotEq) GetValue() string {
+ if x != nil {
+ return x.Value
+ }
+ return ""
+}
+
+// The attribute casted to a string must be in the list of values.
+type WorkloadIdentityConditionIn struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // The list of values to compare the attribute against.
+ Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WorkloadIdentityConditionIn) Reset() {
+ *x = WorkloadIdentityConditionIn{}
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WorkloadIdentityConditionIn) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WorkloadIdentityConditionIn) ProtoMessage() {}
+
+func (x *WorkloadIdentityConditionIn) ProtoReflect() protoreflect.Message {
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WorkloadIdentityConditionIn.ProtoReflect.Descriptor instead.
+func (*WorkloadIdentityConditionIn) Descriptor() ([]byte, []int) {
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{3}
+}
+
+func (x *WorkloadIdentityConditionIn) GetValues() []string {
+ if x != nil {
+ return x.Values
+ }
+ return nil
+}
+
+// The attribute casted to a string must not be in the list of values.
+type WorkloadIdentityConditionNotIn struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // The list of values to compare the attribute against.
+ Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WorkloadIdentityConditionNotIn) Reset() {
+ *x = WorkloadIdentityConditionNotIn{}
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WorkloadIdentityConditionNotIn) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WorkloadIdentityConditionNotIn) ProtoMessage() {}
+
+func (x *WorkloadIdentityConditionNotIn) ProtoReflect() protoreflect.Message {
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WorkloadIdentityConditionNotIn.ProtoReflect.Descriptor instead.
+func (*WorkloadIdentityConditionNotIn) Descriptor() ([]byte, []int) {
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{4}
+}
+
+func (x *WorkloadIdentityConditionNotIn) GetValues() []string {
+ if x != nil {
+ return x.Values
+ }
+ return nil
+}
+
// The individual conditions that make up a rule.
type WorkloadIdentityCondition struct {
state protoimpl.MessageState `protogen:"open.v1"`
// The name of the attribute to evaluate the condition against.
Attribute string `protobuf:"bytes,1,opt,name=attribute,proto3" json:"attribute,omitempty"`
- // An exact string that the attribute must match.
- Equals string `protobuf:"bytes,2,opt,name=equals,proto3" json:"equals,omitempty"`
+ // Types that are valid to be assigned to Operator:
+ //
+ // *WorkloadIdentityCondition_Eq
+ // *WorkloadIdentityCondition_NotEq
+ // *WorkloadIdentityCondition_In
+ // *WorkloadIdentityCondition_NotIn
+ Operator isWorkloadIdentityCondition_Operator `protobuf_oneof:"operator"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *WorkloadIdentityCondition) Reset() {
*x = WorkloadIdentityCondition{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -146,7 +335,7 @@ func (x *WorkloadIdentityCondition) String() string {
func (*WorkloadIdentityCondition) ProtoMessage() {}
func (x *WorkloadIdentityCondition) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -159,7 +348,7 @@ func (x *WorkloadIdentityCondition) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentityCondition.ProtoReflect.Descriptor instead.
func (*WorkloadIdentityCondition) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{1}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{5}
}
func (x *WorkloadIdentityCondition) GetAttribute() string {
@@ -169,13 +358,81 @@ func (x *WorkloadIdentityCondition) GetAttribute() string {
return ""
}
-func (x *WorkloadIdentityCondition) GetEquals() string {
+func (x *WorkloadIdentityCondition) GetOperator() isWorkloadIdentityCondition_Operator {
if x != nil {
- return x.Equals
+ return x.Operator
}
- return ""
+ return nil
+}
+
+func (x *WorkloadIdentityCondition) GetEq() *WorkloadIdentityConditionEq {
+ if x != nil {
+ if x, ok := x.Operator.(*WorkloadIdentityCondition_Eq); ok {
+ return x.Eq
+ }
+ }
+ return nil
+}
+
+func (x *WorkloadIdentityCondition) GetNotEq() *WorkloadIdentityConditionNotEq {
+ if x != nil {
+ if x, ok := x.Operator.(*WorkloadIdentityCondition_NotEq); ok {
+ return x.NotEq
+ }
+ }
+ return nil
}
+func (x *WorkloadIdentityCondition) GetIn() *WorkloadIdentityConditionIn {
+ if x != nil {
+ if x, ok := x.Operator.(*WorkloadIdentityCondition_In); ok {
+ return x.In
+ }
+ }
+ return nil
+}
+
+func (x *WorkloadIdentityCondition) GetNotIn() *WorkloadIdentityConditionNotIn {
+ if x != nil {
+ if x, ok := x.Operator.(*WorkloadIdentityCondition_NotIn); ok {
+ return x.NotIn
+ }
+ }
+ return nil
+}
+
+type isWorkloadIdentityCondition_Operator interface {
+ isWorkloadIdentityCondition_Operator()
+}
+
+type WorkloadIdentityCondition_Eq struct {
+ // The attribute casted to a string must be equal to the value.
+ Eq *WorkloadIdentityConditionEq `protobuf:"bytes,3,opt,name=eq,proto3,oneof"`
+}
+
+type WorkloadIdentityCondition_NotEq struct {
+ // The attribute casted to a string must not be equal to the value.
+ NotEq *WorkloadIdentityConditionNotEq `protobuf:"bytes,4,opt,name=not_eq,json=notEq,proto3,oneof"`
+}
+
+type WorkloadIdentityCondition_In struct {
+ // The attribute casted to a string must be in the list of values.
+ In *WorkloadIdentityConditionIn `protobuf:"bytes,5,opt,name=in,proto3,oneof"`
+}
+
+type WorkloadIdentityCondition_NotIn struct {
+ // The attribute casted to a string must not be in the list of values.
+ NotIn *WorkloadIdentityConditionNotIn `protobuf:"bytes,6,opt,name=not_in,json=notIn,proto3,oneof"`
+}
+
+func (*WorkloadIdentityCondition_Eq) isWorkloadIdentityCondition_Operator() {}
+
+func (*WorkloadIdentityCondition_NotEq) isWorkloadIdentityCondition_Operator() {}
+
+func (*WorkloadIdentityCondition_In) isWorkloadIdentityCondition_Operator() {}
+
+func (*WorkloadIdentityCondition_NotIn) isWorkloadIdentityCondition_Operator() {}
+
// An individual rule that is evaluated during the issuance of a WorkloadIdentity.
type WorkloadIdentityRule struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -187,7 +444,7 @@ type WorkloadIdentityRule struct {
func (x *WorkloadIdentityRule) Reset() {
*x = WorkloadIdentityRule{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -199,7 +456,7 @@ func (x *WorkloadIdentityRule) String() string {
func (*WorkloadIdentityRule) ProtoMessage() {}
func (x *WorkloadIdentityRule) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -212,7 +469,7 @@ func (x *WorkloadIdentityRule) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentityRule.ProtoReflect.Descriptor instead.
func (*WorkloadIdentityRule) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{2}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{6}
}
func (x *WorkloadIdentityRule) GetConditions() []*WorkloadIdentityCondition {
@@ -235,7 +492,7 @@ type WorkloadIdentityRules struct {
func (x *WorkloadIdentityRules) Reset() {
*x = WorkloadIdentityRules{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -247,7 +504,7 @@ func (x *WorkloadIdentityRules) String() string {
func (*WorkloadIdentityRules) ProtoMessage() {}
func (x *WorkloadIdentityRules) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -260,7 +517,7 @@ func (x *WorkloadIdentityRules) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentityRules.ProtoReflect.Descriptor instead.
func (*WorkloadIdentityRules) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{3}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{7}
}
func (x *WorkloadIdentityRules) GetAllow() []*WorkloadIdentityRule {
@@ -284,7 +541,7 @@ type WorkloadIdentitySPIFFEX509 struct {
func (x *WorkloadIdentitySPIFFEX509) Reset() {
*x = WorkloadIdentitySPIFFEX509{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -296,7 +553,7 @@ func (x *WorkloadIdentitySPIFFEX509) String() string {
func (*WorkloadIdentitySPIFFEX509) ProtoMessage() {}
func (x *WorkloadIdentitySPIFFEX509) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[8]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -309,7 +566,7 @@ func (x *WorkloadIdentitySPIFFEX509) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentitySPIFFEX509.ProtoReflect.Descriptor instead.
func (*WorkloadIdentitySPIFFEX509) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{4}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{8}
}
func (x *WorkloadIdentitySPIFFEX509) GetDnsSans() []string {
@@ -341,7 +598,7 @@ type WorkloadIdentitySPIFFE struct {
func (x *WorkloadIdentitySPIFFE) Reset() {
*x = WorkloadIdentitySPIFFE{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -353,7 +610,7 @@ func (x *WorkloadIdentitySPIFFE) String() string {
func (*WorkloadIdentitySPIFFE) ProtoMessage() {}
func (x *WorkloadIdentitySPIFFE) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[9]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -366,7 +623,7 @@ func (x *WorkloadIdentitySPIFFE) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentitySPIFFE.ProtoReflect.Descriptor instead.
func (*WorkloadIdentitySPIFFE) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{5}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{9}
}
func (x *WorkloadIdentitySPIFFE) GetId() string {
@@ -404,7 +661,7 @@ type WorkloadIdentitySpec struct {
func (x *WorkloadIdentitySpec) Reset() {
*x = WorkloadIdentitySpec{}
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -416,7 +673,7 @@ func (x *WorkloadIdentitySpec) String() string {
func (*WorkloadIdentitySpec) ProtoMessage() {}
func (x *WorkloadIdentitySpec) ProtoReflect() protoreflect.Message {
- mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6]
+ mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[10]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -429,7 +686,7 @@ func (x *WorkloadIdentitySpec) ProtoReflect() protoreflect.Message {
// Deprecated: Use WorkloadIdentitySpec.ProtoReflect.Descriptor instead.
func (*WorkloadIdentitySpec) Descriptor() ([]byte, []int) {
- return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{6}
+ return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{10}
}
func (x *WorkloadIdentitySpec) GetRules() *WorkloadIdentityRules {
@@ -469,56 +726,91 @@ var file_teleport_workloadidentity_v1_resource_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e,
0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e,
- 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x51,
- 0x0a, 0x19, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69,
- 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x61,
- 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09,
- 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x65, 0x71, 0x75,
- 0x61, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x65, 0x71, 0x75, 0x61, 0x6c,
- 0x73, 0x22, 0x6f, 0x0a, 0x14, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65,
- 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x57, 0x0a, 0x0a, 0x63, 0x6f, 0x6e,
- 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e,
- 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61,
- 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72,
- 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e,
- 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f,
- 0x6e, 0x73, 0x22, 0x61, 0x0a, 0x15, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64,
- 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x48, 0x0a, 0x05, 0x61,
- 0x6c, 0x6c, 0x6f, 0x77, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c,
- 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64,
- 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f,
- 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05,
- 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x22, 0x37, 0x0a, 0x1a, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61,
- 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58,
- 0x35, 0x30, 0x39, 0x12, 0x19, 0x0a, 0x08, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x61, 0x6e, 0x73, 0x18,
- 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6e, 0x73, 0x53, 0x61, 0x6e, 0x73, 0x22, 0x8a,
- 0x01, 0x0a, 0x16, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74,
- 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x69, 0x6e,
- 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x12, 0x4c, 0x0a,
- 0x04, 0x78, 0x35, 0x30, 0x39, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x74, 0x65,
- 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69,
- 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c,
- 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46,
- 0x45, 0x58, 0x35, 0x30, 0x39, 0x52, 0x04, 0x78, 0x35, 0x30, 0x39, 0x22, 0xaf, 0x01, 0x0a, 0x14,
- 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
- 0x53, 0x70, 0x65, 0x63, 0x12, 0x49, 0x0a, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20,
- 0x01, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77,
+ 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x33,
+ 0x0a, 0x1b, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69,
+ 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x71, 0x12, 0x14, 0x0a,
+ 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61,
+ 0x6c, 0x75, 0x65, 0x22, 0x36, 0x0a, 0x1e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49,
+ 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e,
+ 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x35, 0x0a, 0x1b, 0x57,
+ 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43,
+ 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61,
+ 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75,
+ 0x65, 0x73, 0x22, 0x38, 0x0a, 0x1e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64,
+ 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x4e,
+ 0x6f, 0x74, 0x49, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01,
+ 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x9b, 0x03, 0x0a,
+ 0x19, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74,
+ 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x74,
+ 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61,
+ 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x12, 0x4b, 0x0a, 0x02, 0x65, 0x71, 0x18, 0x03,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x39, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e,
+ 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
+ 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e,
+ 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x71, 0x48,
+ 0x00, 0x52, 0x02, 0x65, 0x71, 0x12, 0x55, 0x0a, 0x06, 0x6e, 0x6f, 0x74, 0x5f, 0x65, 0x71, 0x18,
+ 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x3c, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74,
+ 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74,
+ 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65,
+ 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x4e, 0x6f,
+ 0x74, 0x45, 0x71, 0x48, 0x00, 0x52, 0x05, 0x6e, 0x6f, 0x74, 0x45, 0x71, 0x12, 0x4b, 0x0a, 0x02,
+ 0x69, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x39, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70,
+ 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e,
+ 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64,
+ 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f,
+ 0x6e, 0x49, 0x6e, 0x48, 0x00, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x55, 0x0a, 0x06, 0x6e, 0x6f, 0x74,
+ 0x5f, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x3c, 0x2e, 0x74, 0x65, 0x6c, 0x65,
+ 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65,
+ 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61,
+ 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69,
+ 0x6f, 0x6e, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x48, 0x00, 0x52, 0x05, 0x6e, 0x6f, 0x74, 0x49, 0x6e,
+ 0x42, 0x0a, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x4a, 0x04, 0x08, 0x02,
+ 0x10, 0x03, 0x52, 0x06, 0x65, 0x71, 0x75, 0x61, 0x6c, 0x73, 0x22, 0x6f, 0x0a, 0x14, 0x57, 0x6f,
+ 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75,
+ 0x6c, 0x65, 0x12, 0x57, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73,
+ 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72,
+ 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69,
+ 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64,
+ 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x52,
+ 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x61, 0x0a, 0x15, 0x57,
+ 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52,
+ 0x75, 0x6c, 0x65, 0x73, 0x12, 0x48, 0x0a, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x18, 0x01, 0x20,
+ 0x03, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77,
0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e,
0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74,
- 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x12,
- 0x4c, 0x0a, 0x06, 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
- 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c,
- 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57,
- 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53,
- 0x50, 0x49, 0x46, 0x46, 0x45, 0x52, 0x06, 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x42, 0x64, 0x5a,
- 0x62, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76,
- 0x69, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f,
- 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
- 0x2f, 0x67, 0x6f, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x6f, 0x72,
- 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x76, 0x31,
- 0x3b, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74,
- 0x79, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x22, 0x37,
+ 0x0a, 0x1a, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69,
+ 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58, 0x35, 0x30, 0x39, 0x12, 0x19, 0x0a, 0x08,
+ 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x61, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07,
+ 0x64, 0x6e, 0x73, 0x53, 0x61, 0x6e, 0x73, 0x22, 0x8a, 0x01, 0x0a, 0x16, 0x57, 0x6f, 0x72, 0x6b,
+ 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46,
+ 0x46, 0x45, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02,
+ 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x12, 0x4c, 0x0a, 0x04, 0x78, 0x35, 0x30, 0x39, 0x18, 0x03,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e,
+ 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
+ 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e,
+ 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58, 0x35, 0x30, 0x39, 0x52, 0x04,
+ 0x78, 0x35, 0x30, 0x39, 0x22, 0xaf, 0x01, 0x0a, 0x14, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61,
+ 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 0x63, 0x12, 0x49, 0x0a,
+ 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x74,
+ 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64,
+ 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b,
+ 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65,
+ 0x73, 0x52, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x4c, 0x0a, 0x06, 0x73, 0x70, 0x69, 0x66,
+ 0x66, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70,
+ 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e,
+ 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64,
+ 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x52, 0x06,
+ 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x42, 0x64, 0x5a, 0x62, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
+ 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, 0x69, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e,
+ 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f,
+ 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x74, 0x65, 0x6c,
+ 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64,
+ 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x76, 0x31, 0x3b, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f,
+ 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72,
+ 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -533,30 +825,38 @@ func file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP() []byte {
return file_teleport_workloadidentity_v1_resource_proto_rawDescData
}
-var file_teleport_workloadidentity_v1_resource_proto_msgTypes = make([]protoimpl.MessageInfo, 7)
+var file_teleport_workloadidentity_v1_resource_proto_msgTypes = make([]protoimpl.MessageInfo, 11)
var file_teleport_workloadidentity_v1_resource_proto_goTypes = []any{
- (*WorkloadIdentity)(nil), // 0: teleport.workloadidentity.v1.WorkloadIdentity
- (*WorkloadIdentityCondition)(nil), // 1: teleport.workloadidentity.v1.WorkloadIdentityCondition
- (*WorkloadIdentityRule)(nil), // 2: teleport.workloadidentity.v1.WorkloadIdentityRule
- (*WorkloadIdentityRules)(nil), // 3: teleport.workloadidentity.v1.WorkloadIdentityRules
- (*WorkloadIdentitySPIFFEX509)(nil), // 4: teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509
- (*WorkloadIdentitySPIFFE)(nil), // 5: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE
- (*WorkloadIdentitySpec)(nil), // 6: teleport.workloadidentity.v1.WorkloadIdentitySpec
- (*v1.Metadata)(nil), // 7: teleport.header.v1.Metadata
+ (*WorkloadIdentity)(nil), // 0: teleport.workloadidentity.v1.WorkloadIdentity
+ (*WorkloadIdentityConditionEq)(nil), // 1: teleport.workloadidentity.v1.WorkloadIdentityConditionEq
+ (*WorkloadIdentityConditionNotEq)(nil), // 2: teleport.workloadidentity.v1.WorkloadIdentityConditionNotEq
+ (*WorkloadIdentityConditionIn)(nil), // 3: teleport.workloadidentity.v1.WorkloadIdentityConditionIn
+ (*WorkloadIdentityConditionNotIn)(nil), // 4: teleport.workloadidentity.v1.WorkloadIdentityConditionNotIn
+ (*WorkloadIdentityCondition)(nil), // 5: teleport.workloadidentity.v1.WorkloadIdentityCondition
+ (*WorkloadIdentityRule)(nil), // 6: teleport.workloadidentity.v1.WorkloadIdentityRule
+ (*WorkloadIdentityRules)(nil), // 7: teleport.workloadidentity.v1.WorkloadIdentityRules
+ (*WorkloadIdentitySPIFFEX509)(nil), // 8: teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509
+ (*WorkloadIdentitySPIFFE)(nil), // 9: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE
+ (*WorkloadIdentitySpec)(nil), // 10: teleport.workloadidentity.v1.WorkloadIdentitySpec
+ (*v1.Metadata)(nil), // 11: teleport.header.v1.Metadata
}
var file_teleport_workloadidentity_v1_resource_proto_depIdxs = []int32{
- 7, // 0: teleport.workloadidentity.v1.WorkloadIdentity.metadata:type_name -> teleport.header.v1.Metadata
- 6, // 1: teleport.workloadidentity.v1.WorkloadIdentity.spec:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySpec
- 1, // 2: teleport.workloadidentity.v1.WorkloadIdentityRule.conditions:type_name -> teleport.workloadidentity.v1.WorkloadIdentityCondition
- 2, // 3: teleport.workloadidentity.v1.WorkloadIdentityRules.allow:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRule
- 4, // 4: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE.x509:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509
- 3, // 5: teleport.workloadidentity.v1.WorkloadIdentitySpec.rules:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRules
- 5, // 6: teleport.workloadidentity.v1.WorkloadIdentitySpec.spiffe:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFE
- 7, // [7:7] is the sub-list for method output_type
- 7, // [7:7] is the sub-list for method input_type
- 7, // [7:7] is the sub-list for extension type_name
- 7, // [7:7] is the sub-list for extension extendee
- 0, // [0:7] is the sub-list for field type_name
+ 11, // 0: teleport.workloadidentity.v1.WorkloadIdentity.metadata:type_name -> teleport.header.v1.Metadata
+ 10, // 1: teleport.workloadidentity.v1.WorkloadIdentity.spec:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySpec
+ 1, // 2: teleport.workloadidentity.v1.WorkloadIdentityCondition.eq:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionEq
+ 2, // 3: teleport.workloadidentity.v1.WorkloadIdentityCondition.not_eq:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionNotEq
+ 3, // 4: teleport.workloadidentity.v1.WorkloadIdentityCondition.in:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionIn
+ 4, // 5: teleport.workloadidentity.v1.WorkloadIdentityCondition.not_in:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionNotIn
+ 5, // 6: teleport.workloadidentity.v1.WorkloadIdentityRule.conditions:type_name -> teleport.workloadidentity.v1.WorkloadIdentityCondition
+ 6, // 7: teleport.workloadidentity.v1.WorkloadIdentityRules.allow:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRule
+ 8, // 8: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE.x509:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509
+ 7, // 9: teleport.workloadidentity.v1.WorkloadIdentitySpec.rules:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRules
+ 9, // 10: teleport.workloadidentity.v1.WorkloadIdentitySpec.spiffe:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFE
+ 11, // [11:11] is the sub-list for method output_type
+ 11, // [11:11] is the sub-list for method input_type
+ 11, // [11:11] is the sub-list for extension type_name
+ 11, // [11:11] is the sub-list for extension extendee
+ 0, // [0:11] is the sub-list for field type_name
}
func init() { file_teleport_workloadidentity_v1_resource_proto_init() }
@@ -564,13 +864,19 @@ func file_teleport_workloadidentity_v1_resource_proto_init() {
if File_teleport_workloadidentity_v1_resource_proto != nil {
return
}
+ file_teleport_workloadidentity_v1_resource_proto_msgTypes[5].OneofWrappers = []any{
+ (*WorkloadIdentityCondition_Eq)(nil),
+ (*WorkloadIdentityCondition_NotEq)(nil),
+ (*WorkloadIdentityCondition_In)(nil),
+ (*WorkloadIdentityCondition_NotIn)(nil),
+ }
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_teleport_workloadidentity_v1_resource_proto_rawDesc,
NumEnums: 0,
- NumMessages: 7,
+ NumMessages: 11,
NumExtensions: 0,
NumServices: 0,
},
diff --git a/api/proto/teleport/workloadidentity/v1/resource.proto b/api/proto/teleport/workloadidentity/v1/resource.proto
index b0faf7f94b99e..ad4cc03cf4c24 100644
--- a/api/proto/teleport/workloadidentity/v1/resource.proto
+++ b/api/proto/teleport/workloadidentity/v1/resource.proto
@@ -38,12 +38,46 @@ message WorkloadIdentity {
WorkloadIdentitySpec spec = 5;
}
+// The attribute casted to a string must be equal to the value.
+message WorkloadIdentityConditionEq {
+ // The value to compare the attribute against.
+ string value = 1;
+}
+
+// The attribute casted to a string must not be equal to the value.
+message WorkloadIdentityConditionNotEq {
+ // The value to compare the attribute against.
+ string value = 1;
+}
+
+// The attribute casted to a string must be in the list of values.
+message WorkloadIdentityConditionIn {
+ // The list of values to compare the attribute against.
+ repeated string values = 1;
+}
+
+// The attribute casted to a string must not be in the list of values.
+message WorkloadIdentityConditionNotIn {
+ // The list of values to compare the attribute against.
+ repeated string values = 1;
+}
+
// The individual conditions that make up a rule.
message WorkloadIdentityCondition {
+ reserved 2;
+ reserved "equals";
// The name of the attribute to evaluate the condition against.
string attribute = 1;
- // An exact string that the attribute must match.
- string equals = 2;
+ oneof operator {
+ // The attribute casted to a string must be equal to the value.
+ WorkloadIdentityConditionEq eq = 3;
+ // The attribute casted to a string must not be equal to the value.
+ WorkloadIdentityConditionNotEq not_eq = 4;
+ // The attribute casted to a string must be in the list of values.
+ WorkloadIdentityConditionIn in = 5;
+ // The attribute casted to a string must not be in the list of values.
+ WorkloadIdentityConditionNotIn not_in = 6;
+ }
}
// An individual rule that is evaluated during the issuance of a WorkloadIdentity.
diff --git a/docs/pages/admin-guides/access-controls/access-monitoring.mdx b/docs/pages/admin-guides/access-controls/access-monitoring.mdx
index 7f5a7b2a0a864..25797cf3e89d3 100644
--- a/docs/pages/admin-guides/access-controls/access-monitoring.mdx
+++ b/docs/pages/admin-guides/access-controls/access-monitoring.mdx
@@ -17,7 +17,7 @@ Users are able to write their own custom access monitoring queries by querying t
Access Monitoring is not currently supported with External Audit Storage
- in Teleport Enterprise (cloud-hosted). This functionality will be
+ in Teleport Enterprise (Cloud). This functionality will be
enabled in a future Teleport release.
diff --git a/docs/pages/admin-guides/management/external-audit-storage.mdx b/docs/pages/admin-guides/management/external-audit-storage.mdx
index 6aa2fcc0368b8..587bb7ffebe56 100644
--- a/docs/pages/admin-guides/management/external-audit-storage.mdx
+++ b/docs/pages/admin-guides/management/external-audit-storage.mdx
@@ -21,6 +21,12 @@ External Audit Storage is based on Teleport's
available on Teleport Enterprise Cloud clusters running Teleport v14.2.1 or
above.
+
+On Teleport Enterprise (Cloud), External Audit
+Storage is not currently supported for users who have Access Monitoring enabled.
+This functionality will be enabled in a future Teleport release.
+
+
## Prerequisites
1. A Teleport Enterprise Cloud account. If you do not have one, [sign
diff --git a/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx b/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx
index 7c7e1a05a5af0..6a5f12830bf4f 100644
--- a/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx
+++ b/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx
@@ -55,7 +55,38 @@ Optional:
Optional:
- `attribute` (String) The name of the attribute to evaluate the condition against.
-- `equals` (String) An exact string that the attribute must match.
+- `eq` (Attributes) The attribute casted to a string must be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionseq))
+- `in` (Attributes) The attribute casted to a string must be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsin))
+- `not_eq` (Attributes) The attribute casted to a string must not be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_eq))
+- `not_in` (Attributes) The attribute casted to a string must not be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_in))
+
+### Nested Schema for `spec.rules.allow.conditions.eq`
+
+Optional:
+
+- `value` (String) The value to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.in`
+
+Optional:
+
+- `values` (List of String) The list of values to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.not_eq`
+
+Optional:
+
+- `value` (String) The value to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.not_in`
+
+Optional:
+
+- `values` (List of String) The list of values to compare the attribute against.
+
diff --git a/docs/pages/reference/terraform-provider/resources/workload_identity.mdx b/docs/pages/reference/terraform-provider/resources/workload_identity.mdx
index fbbeb1306abd8..6238a0d535b03 100644
--- a/docs/pages/reference/terraform-provider/resources/workload_identity.mdx
+++ b/docs/pages/reference/terraform-provider/resources/workload_identity.mdx
@@ -23,7 +23,9 @@ resource "teleport_workload_identity" "example" {
{
conditions = [{
attribute = "user.name"
- equals = "noah"
+ eq = {
+ value = "my-user"
+ }
}]
}
]
@@ -80,7 +82,38 @@ Optional:
Optional:
- `attribute` (String) The name of the attribute to evaluate the condition against.
-- `equals` (String) An exact string that the attribute must match.
+- `eq` (Attributes) The attribute casted to a string must be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionseq))
+- `in` (Attributes) The attribute casted to a string must be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsin))
+- `not_eq` (Attributes) The attribute casted to a string must not be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_eq))
+- `not_in` (Attributes) The attribute casted to a string must not be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_in))
+
+### Nested Schema for `spec.rules.allow.conditions.eq`
+
+Optional:
+
+- `value` (String) The value to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.in`
+
+Optional:
+
+- `values` (List of String) The list of values to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.not_eq`
+
+Optional:
+
+- `value` (String) The value to compare the attribute against.
+
+
+### Nested Schema for `spec.rules.allow.conditions.not_in`
+
+Optional:
+
+- `values` (List of String) The list of values to compare the attribute against.
+
diff --git a/go.mod b/go.mod
index 3c35132910093..625a780eb3ff6 100644
--- a/go.mod
+++ b/go.mod
@@ -48,6 +48,7 @@ require (
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.23
github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22
+ github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45
github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.3
github.com/aws/aws-sdk-go-v2/service/athena v1.49.2
@@ -179,7 +180,6 @@ require (
github.com/sigstore/cosign/v2 v2.4.1
github.com/sigstore/sigstore v1.8.11
github.com/sijms/go-ora/v2 v2.8.22
- github.com/sirupsen/logrus v1.9.3
github.com/snowflakedb/gosnowflake v1.12.1
github.com/spf13/cobra v1.8.1
github.com/spiffe/go-spiffe/v2 v2.4.0
@@ -501,6 +501,7 @@ require (
github.com/sigstore/protobuf-specs v0.3.2 // indirect
github.com/sigstore/rekor v1.3.6 // indirect
github.com/sigstore/timestamp-authority v1.2.2 // indirect
+ github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
diff --git a/go.sum b/go.sum
index 5665c4f7280c7..5bf38ba7fc0c4 100644
--- a/go.sum
+++ b/go.sum
@@ -866,6 +866,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58/go.
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3/go.mod h1:4Q0UFP0YJf0NrsEuEYHpM9fTSEVnD16Z3uyEF7J9JGM=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M=
+github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU=
+github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45 h1:ZxB8WFVYwolhDZxuZXoesHkl+L9cXLWy0K/G0QkNATc=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45/go.mod h1:1krrbyoFFDqaNldmltPTP+mK3sAXLHPoaFtISOw2Hkk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw=
diff --git a/integrations/access/accesslist/app.go b/integrations/access/accesslist/app.go
index 02f933baf5ecd..ba40de3abf575 100644
--- a/integrations/access/accesslist/app.go
+++ b/integrations/access/accesslist/app.go
@@ -33,6 +33,7 @@ import (
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/logger"
pd "github.com/gravitational/teleport/integrations/lib/plugindata"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -118,7 +119,7 @@ func (a *App) run(ctx context.Context) error {
log := logger.Get(ctx)
- log.Info("Access list monitor is running")
+ log.InfoContext(ctx, "Access list monitor is running")
a.job.SetReady(true)
@@ -134,7 +135,7 @@ func (a *App) run(ctx context.Context) error {
}
timer.Reset(jitter(reminderInterval))
case <-ctx.Done():
- log.Info("Access list monitor is finished")
+ log.InfoContext(ctx, "Access list monitor is finished")
return nil
}
}
@@ -146,7 +147,7 @@ func (a *App) run(ctx context.Context) error {
func (a *App) remindIfNecessary(ctx context.Context) error {
log := logger.Get(ctx)
- log.Info("Looking for Access List Review reminders")
+ log.InfoContext(ctx, "Looking for Access List Review reminders")
var nextToken string
var err error
@@ -156,13 +157,14 @@ func (a *App) remindIfNecessary(ctx context.Context) error {
accessLists, nextToken, err = a.apiClient.ListAccessLists(ctx, 0 /* default page size */, nextToken)
if err != nil {
if trace.IsNotImplemented(err) {
- log.Errorf("access list endpoint is not implemented on this auth server, so the access list app is ceasing to run.")
+ log.ErrorContext(ctx, "access list endpoint is not implemented on this auth server, so the access list app is ceasing to run")
return trace.Wrap(err)
} else if trace.IsAccessDenied(err) {
- log.Warnf("Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " +
- "to the role associated with the Slack bot.")
+ const msg = "Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " +
+ "to the role associated with the Slack bot."
+ log.WarnContext(ctx, msg)
} else {
- log.Errorf("error listing access lists: %v", err)
+ log.ErrorContext(ctx, "error listing access lists", "error", err)
}
break
}
@@ -170,7 +172,10 @@ func (a *App) remindIfNecessary(ctx context.Context) error {
for _, accessList := range accessLists {
recipients, err := a.getRecipientsRequiringReminders(ctx, accessList)
if err != nil {
- log.WithError(err).Warnf("Error getting recipients to notify for review due for access list %q", accessList.Spec.Title)
+ log.WarnContext(ctx, "Error getting recipients to notify for review due for access list",
+ "error", err,
+ "access_list", accessList.Spec.Title,
+ )
continue
}
@@ -195,7 +200,7 @@ func (a *App) remindIfNecessary(ctx context.Context) error {
}
if len(errs) > 0 {
- log.WithError(trace.NewAggregate(errs...)).Warn("Error notifying for access list reviews")
+ log.WarnContext(ctx, "Error notifying for access list reviews", "error", trace.NewAggregate(errs...))
}
return nil
@@ -213,7 +218,10 @@ func (a *App) getRecipientsRequiringReminders(ctx context.Context, accessList *a
// If the current time before the notification start time, skip notifications.
if now.Before(notificationStart) {
- log.Debugf("Access list %s is not ready for notifications, notifications start at %s", accessList.GetName(), notificationStart.Format(time.RFC3339))
+ log.DebugContext(ctx, "Access list is not ready for notifications",
+ "access_list", accessList.GetName(),
+ "notification_start_time", notificationStart.Format(time.RFC3339),
+ )
return nil, nil
}
@@ -255,12 +263,17 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access
if err != nil {
// TODO(kiosion): Remove in v18; protecting against server not having `GetAccessListOwners` func.
if trace.IsNotImplemented(err) {
- log.WithError(err).Warnf("Error getting nested owners for access list '%v', continuing with only explicit owners", accessList.GetName())
+ log.WarnContext(ctx, "Error getting nested owners for access list, continuing with only explicit owners",
+ "error", err,
+ "access_list", accessList.GetName(),
+ )
for _, owner := range accessList.Spec.Owners {
allOwners = append(allOwners, &owner)
}
} else {
- log.WithError(err).Errorf("Error getting owners for access list '%v'", accessList.GetName())
+ log.ErrorContext(ctx, "Error getting owners for access list",
+ "error", err,
+ "access_list", accessList.GetName())
}
}
@@ -270,7 +283,7 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access
for _, owner := range allOwners {
recipient, err := a.bot.FetchRecipient(ctx, owner.Name)
if err != nil {
- log.Debugf("error getting recipient %s", owner.Name)
+ log.DebugContext(ctx, "error getting recipient", "recipient", owner.Name)
continue
}
allRecipients[owner.Name] = *recipient
@@ -293,7 +306,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con
// Calculate days from start.
daysFromStart := now.Sub(notificationStart) / oneDay
windowStart = notificationStart.Add(daysFromStart * oneDay)
- log.Infof("windowStart: %s, now: %s", windowStart.String(), now.String())
+ log.InfoContext(ctx, "calculating window start",
+ "window_start", logutils.StringerAttr(windowStart),
+ "now", logutils.StringerAttr(now),
+ )
}
recipients := []common.Recipient{}
@@ -304,7 +320,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con
// If the notification window is before the last notification date, then this user doesn't need a notification.
if !windowStart.After(lastNotification) {
- log.Debugf("User %s has already been notified for access list %s", recipient.Name, accessList.GetName())
+ log.DebugContext(ctx, "User has already been notified for access list",
+ "user", recipient.Name,
+ "access_list", accessList.GetName(),
+ )
userNotifications[recipient.Name] = lastNotification
continue
}
diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go
index 3dea9ea2bf543..82c91413bff96 100644
--- a/integrations/access/accessmonitoring/access_monitoring_rules.go
+++ b/integrations/access/accessmonitoring/access_monitoring_rules.go
@@ -151,8 +151,10 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context
for _, rule := range amrh.getAccessMonitoringRules() {
match, err := MatchAccessRequest(rule.Spec.Condition, req)
if err != nil {
- log.WithError(err).WithField("rule", rule.Metadata.Name).
- Warn("Failed to parse access monitoring notification rule")
+ log.WarnContext(ctx, "Failed to parse access monitoring notification rule",
+ "error", err,
+ "rule", rule.Metadata.Name,
+ )
}
if !match {
continue
@@ -160,7 +162,7 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context
for _, recipient := range rule.Spec.Notification.Recipients {
rec, err := amrh.fetchRecipientCallback(ctx, recipient)
if err != nil {
- log.WithError(err).Warn("Failed to fetch plugin recipients based on Access monitoring rule recipients")
+ log.WarnContext(ctx, "Failed to fetch plugin recipients based on Access monitoring rule recipients", "error", err)
continue
}
recipientSet.Add(*rec)
@@ -176,8 +178,10 @@ func (amrh *RuleHandler) RawRecipientsFromAccessMonitoringRules(ctx context.Cont
for _, rule := range amrh.getAccessMonitoringRules() {
match, err := MatchAccessRequest(rule.Spec.Condition, req)
if err != nil {
- log.WithError(err).WithField("rule", rule.Metadata.Name).
- Warn("Failed to parse access monitoring notification rule")
+ log.WarnContext(ctx, "Failed to parse access monitoring notification rule",
+ "error", err,
+ "rule", rule.Metadata.Name,
+ )
}
if !match {
continue
diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go
index 8a5effc73dabd..17182ec3dc8ee 100644
--- a/integrations/access/accessrequest/app.go
+++ b/integrations/access/accessrequest/app.go
@@ -21,6 +21,7 @@ package accessrequest
import (
"context"
"fmt"
+ "log/slog"
"slices"
"strings"
"time"
@@ -36,6 +37,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
pd "github.com/gravitational/teleport/integrations/lib/plugindata"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -189,16 +191,16 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error {
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.BadParameter("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
var err error
switch {
@@ -207,21 +209,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
case req.GetState().IsResolved():
err = a.onResolvedRequest(ctx, req)
default:
- log.WithField("event", event).Warn("Unknown request state")
+ log.WarnContext(ctx, "Unknown request state",
+ slog.Group("event",
+ slog.Any("type", logutils.StringerAttr(event.Type)),
+ slog.Group("resource",
+ "kind", event.Resource.GetKind(),
+ "name", event.Resource.GetName(),
+ ),
+ ),
+ )
return nil
}
if err != nil {
- log.WithError(err).Errorf("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Errorf("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -242,7 +252,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err
loginsByRole, err := a.getLoginsByRole(ctx, req)
if trace.IsAccessDenied(err) {
- log.Warnf("Missing permissions to get logins by role. Please add role.read to the associated role. error: %s", err)
+ log.WarnContext(ctx, "Missing permissions to get logins by role, please add role.read to the associated role", "error", err)
} else if err != nil {
return trace.Wrap(err)
}
@@ -265,12 +275,12 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err
return trace.Wrap(err)
}
} else {
- log.Warning("No channel to post")
+ log.WarnContext(ctx, "No channel to post")
}
// Try to approve the request if user is currently on-call.
if err := a.tryApproveRequest(ctx, reqID, req); err != nil {
- log.Warningf("Failed to auto approve request: %v", err)
+ log.WarnContext(ctx, "Failed to auto approve request", "error", err)
}
case trace.IsAlreadyExists(err):
// The messages were already sent, nothing to do, we can update the reviews
@@ -311,7 +321,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er
case types.RequestState_PROMOTED:
tag = pd.ResolvedPromoted
default:
- logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String())
+ logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state))
return replyErr
}
err := trace.Wrap(a.updateMessages(ctx, req.GetName(), tag, reason, req.GetReviews()))
@@ -330,13 +340,13 @@ func (a *App) broadcastAccessRequestMessages(ctx context.Context, recipients []c
return trace.Wrap(err)
}
for _, data := range sentMessages {
- logger.Get(ctx).WithFields(logger.Fields{
- "channel_id": data.ChannelID,
- "message_id": data.MessageID,
- }).Info("Successfully posted messages")
+ logger.Get(ctx).InfoContext(ctx, "Successfully posted messages",
+ "channel_id", data.ChannelID,
+ "message_id", data.MessageID,
+ )
}
if err != nil {
- logger.Get(ctx).WithError(err).Error("Failed to post one or more messages")
+ logger.Get(ctx).ErrorContext(ctx, "Failed to post one or more messages", "error", err)
}
_, err = a.pluginData.Update(ctx, reqID, func(existing PluginData) (PluginData, error) {
@@ -369,7 +379,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews []
return existing, nil
})
if trace.IsAlreadyExists(err) {
- logger.Get(ctx).Debug("Failed to post reply: replies are already sent")
+ logger.Get(ctx).DebugContext(ctx, "Failed to post reply: replies are already sent")
return nil
}
if err != nil {
@@ -383,7 +393,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews []
errors := make([]error, 0, len(slice))
for _, data := range pd.SentMessages {
- ctx, _ = logger.WithFields(ctx, logger.Fields{"channel_id": data.ChannelID, "message_id": data.MessageID})
+ ctx, _ = logger.With(ctx, "channel_id", data.ChannelID, "message_id", data.MessageID)
for _, review := range slice {
if err := a.bot.PostReviewReply(ctx, data.ChannelID, data.MessageID, review); err != nil {
errors = append(errors, err)
@@ -425,7 +435,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest)
for _, recipient := range recipients {
rec, err := a.bot.FetchRecipient(ctx, recipient)
if err != nil {
- log.Warningf("Failed to fetch Opsgenie recipient: %v", err)
+ log.WarnContext(ctx, "Failed to fetch Opsgenie recipient", "error", err)
continue
}
recipientSet.Add(*rec)
@@ -436,7 +446,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest)
validEmailSuggReviewers := []string{}
for _, reviewer := range req.GetSuggestedReviewers() {
if !lib.IsEmail(reviewer) {
- log.Warningf("Failed to notify a suggested reviewer: %q does not look like a valid email", reviewer)
+ log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", reviewer)
continue
}
@@ -446,7 +456,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest)
for _, rawRecipient := range rawRecipients {
recipient, err := a.bot.FetchRecipient(ctx, rawRecipient)
if err != nil {
- log.WithError(err).Warn("Failure when fetching recipient, continuing anyway")
+ log.WarnContext(ctx, "Failure when fetching recipient, continuing anyway", "error", err)
} else {
recipientSet.Add(*recipient)
}
@@ -476,7 +486,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio
return existing, nil
})
if trace.IsNotFound(err) {
- log.Debug("Failed to update messages: plugin data is missing")
+ log.DebugContext(ctx, "Failed to update messages: plugin data is missing")
return nil
}
if trace.IsAlreadyExists(err) {
@@ -485,7 +495,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio
"cannot change the resolution tag of an already resolved request, existing: %s, event: %s",
pluginData.ResolutionTag, tag)
}
- log.Debug("Request is already resolved, ignoring event")
+ log.DebugContext(ctx, "Request is already resolved, ignoring event")
return nil
}
if err != nil {
@@ -496,13 +506,17 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio
if err := a.bot.UpdateMessages(ctx, reqID, reqData, sentMessages, reviews); err != nil {
return trace.Wrap(err)
}
- log.Infof("Successfully marked request as %s in all messages", tag)
+
+ log.InfoContext(ctx, "Marked request with resolution and sent emails!", "resolution", tag)
if err := a.bot.NotifyUser(ctx, reqID, reqData); err != nil && !trace.IsNotImplemented(err) {
return trace.Wrap(err)
}
- log.Infof("Successfully notified user %s request marked as %s", reqData.User, tag)
+ log.InfoContext(ctx, "Successfully notified user",
+ "user", reqData.User,
+ "resolution", tag,
+ )
return nil
}
@@ -545,13 +559,11 @@ func (a *App) getResourceNames(ctx context.Context, req types.AccessRequest) ([]
// tryApproveRequest attempts to automatically approve the access request if the
// user is on call for the configured service/team.
func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.AccessRequest) error {
- log := logger.Get(ctx).
- WithField("req_id", reqID).
- WithField("user", req.GetUser())
+ log := logger.Get(ctx).With("req_id", reqID, "user", req.GetUser())
oncallUsers, err := a.bot.FetchOncallUsers(ctx, req)
if trace.IsNotImplemented(err) {
- log.Debugf("Skipping auto-approval because %q bot does not support automatic approvals.", a.pluginName)
+ log.DebugContext(ctx, "Skipping auto-approval because bot does not support automatic approvals", "bot", a.pluginName)
return nil
}
if err != nil {
@@ -559,7 +571,7 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc
}
if !slices.Contains(oncallUsers, req.GetUser()) {
- log.Debug("Skipping approval because user is not on-call.")
+ log.DebugContext(ctx, "Skipping approval because user is not on-call")
return nil
}
@@ -573,12 +585,12 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc
},
}); err != nil {
if strings.HasSuffix(err.Error(), "has already reviewed this request") {
- log.Debug("Request has already been reviewed.")
+ log.DebugContext(ctx, "Request has already been reviewed")
return nil
}
return trace.Wrap(err)
}
- log.Info("Successfully submitted a request approval.")
+ log.InfoContext(ctx, "Successfully submitted a request approval")
return nil
}
diff --git a/integrations/access/common/app.go b/integrations/access/common/app.go
index 805c0dde6ef8a..6c174e1422b75 100644
--- a/integrations/access/common/app.go
+++ b/integrations/access/common/app.go
@@ -88,7 +88,7 @@ func (a *BaseApp) WaitReady(ctx context.Context) (bool, error) {
func (a *BaseApp) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.APIClient.Ping(ctx)
if err != nil {
@@ -156,9 +156,9 @@ func (a *BaseApp) run(ctx context.Context) error {
a.mainJob.SetReady(allOK)
if allOK {
- log.Info("Plugin is ready")
+ log.InfoContext(ctx, "Plugin is ready")
} else {
- log.Error("Plugin is not ready")
+ log.ErrorContext(ctx, "Plugin is not ready")
}
for _, app := range a.apps {
@@ -203,11 +203,11 @@ func (a *BaseApp) init(ctx context.Context) error {
}
}
- log.Debug("Starting API health check...")
+ log.DebugContext(ctx, "Starting API health check")
if err = a.Bot.CheckHealth(ctx); err != nil {
return trace.Wrap(err, "API health check failed")
}
- log.Debug("API health check finished ok")
+ log.DebugContext(ctx, "API health check finished ok")
return nil
}
diff --git a/integrations/access/common/auth/token_provider.go b/integrations/access/common/auth/token_provider.go
index f4ae33936a709..e0c23b0b36427 100644
--- a/integrations/access/common/auth/token_provider.go
+++ b/integrations/access/common/auth/token_provider.go
@@ -20,12 +20,12 @@ package auth
import (
"context"
+ "log/slog"
"sync"
"time"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
- "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/common/auth/oauth"
"github.com/gravitational/teleport/integrations/access/common/auth/storage"
@@ -65,7 +65,7 @@ type RotatedAccessTokenProviderConfig struct {
Refresher oauth.Refresher
Clock clockwork.Clock
- Log *logrus.Entry
+ Log *slog.Logger
}
// CheckAndSetDefaults validates a configuration and sets default values
@@ -87,7 +87,7 @@ func (c *RotatedAccessTokenProviderConfig) CheckAndSetDefaults() error {
c.Clock = clockwork.NewRealClock()
}
if c.Log == nil {
- c.Log = logrus.NewEntry(logrus.StandardLogger())
+ c.Log = slog.Default()
}
return nil
}
@@ -104,7 +104,7 @@ type RotatedAccessTokenProvider struct {
refresher oauth.Refresher
clock clockwork.Clock
- log logrus.FieldLogger
+ log *slog.Logger
lock sync.RWMutex // protects the below fields
creds *storage.Credentials
@@ -153,12 +153,12 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
timer := r.clock.NewTimer(interval)
defer timer.Stop()
- r.log.Infof("Will attempt token refresh in: %s", interval)
+ r.log.InfoContext(ctx, "Starting token refresh loop", "next_refresh", interval)
for {
select {
case <-ctx.Done():
- r.log.Info("Shutting down")
+ r.log.InfoContext(ctx, "Shutting down")
return
case <-timer.Chan():
creds, _ := r.store.GetCredentials(ctx)
@@ -174,18 +174,21 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
interval := r.getRefreshInterval(creds)
timer.Reset(interval)
- r.log.Infof("Next refresh in: %s", interval)
+ r.log.InfoContext(ctx, "Refreshed token", "next_refresh", interval)
continue
}
creds, err := r.refresh(ctx)
if err != nil {
- r.log.Errorf("Error while refreshing: %s. Will retry after: %s", err, r.retryInterval)
+ r.log.ErrorContext(ctx, "Error while refreshing token",
+ "error", err,
+ "retry_interval", r.retryInterval,
+ )
timer.Reset(r.retryInterval)
} else {
err := r.store.PutCredentials(ctx, creds)
if err != nil {
- r.log.Errorf("Error while storing the refreshed credentials: %s", err)
+ r.log.ErrorContext(ctx, "Error while storing the refreshed credentials", "error", err)
timer.Reset(r.retryInterval)
continue
}
@@ -196,7 +199,7 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
interval := r.getRefreshInterval(creds)
timer.Reset(interval)
- r.log.Infof("Successfully refreshed credentials. Next refresh in: %s", interval)
+ r.log.InfoContext(ctx, "Successfully refreshed credentials", "next_refresh", interval)
}
}
}
diff --git a/integrations/access/common/auth/token_provider_test.go b/integrations/access/common/auth/token_provider_test.go
index fca79776ba024..e4f02ec3d3ae5 100644
--- a/integrations/access/common/auth/token_provider_test.go
+++ b/integrations/access/common/auth/token_provider_test.go
@@ -20,12 +20,12 @@ package auth
import (
"context"
+ "log/slog"
"testing"
"time"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
- "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/integrations/access/common/auth/oauth"
@@ -57,9 +57,6 @@ func (s *mockStore) PutCredentials(ctx context.Context, creds *storage.Credentia
}
func TestRotatedAccessTokenProvider(t *testing.T) {
- log := logrus.New()
- log.Level = logrus.DebugLevel
-
newProvider := func(ctx context.Context, store storage.Store, refresher oauth.Refresher, clock clockwork.Clock, initialCreds *storage.Credentials) *RotatedAccessTokenProvider {
return &RotatedAccessTokenProvider{
store: store,
@@ -70,7 +67,7 @@ func TestRotatedAccessTokenProvider(t *testing.T) {
tokenBufferInterval: 1 * time.Hour,
creds: initialCreds,
- log: log,
+ log: slog.Default(),
}
}
diff --git a/integrations/access/datadog/bot.go b/integrations/access/datadog/bot.go
index e92dbbb524a20..4e1f52a6c218d 100644
--- a/integrations/access/datadog/bot.go
+++ b/integrations/access/datadog/bot.go
@@ -162,7 +162,7 @@ func (b Bot) FetchOncallUsers(ctx context.Context, req types.AccessRequest) ([]s
annotationKey := types.TeleportNamespace + types.ReqAnnotationApproveSchedulesLabel
teamNames, err := common.GetNamesFromAnnotations(req, annotationKey)
if err != nil {
- log.Debug("Automatic approvals annotation is empty or unspecified.")
+ log.DebugContext(ctx, "Automatic approvals annotation is empty or unspecified")
return nil, nil
}
diff --git a/integrations/access/datadog/client.go b/integrations/access/datadog/client.go
index 489eb0c51a44d..2d4ebf79ea5f2 100644
--- a/integrations/access/datadog/client.go
+++ b/integrations/access/datadog/client.go
@@ -126,7 +126,7 @@ func onAfterDatadogResponse(sink common.StatusSink) resty.ResponseMiddleware {
defer cancel()
if err := sink.Emit(ctx, status); err != nil {
- log.WithError(err).Errorf("Error while emitting Datadog Incident Management plugin status: %v", err)
+ log.ErrorContext(ctx, "Error while emitting Datadog Incident Management plugin status", "error", err)
}
}
diff --git a/integrations/access/datadog/cmd/teleport-datadog/main.go b/integrations/access/datadog/cmd/teleport-datadog/main.go
index cb9cbd1959771..84a6a14c0955f 100644
--- a/integrations/access/datadog/cmd/teleport-datadog/main.go
+++ b/integrations/access/datadog/cmd/teleport-datadog/main.go
@@ -22,6 +22,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -67,12 +68,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := datadog.LoadConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -86,14 +88,15 @@ func run(configPath string, debug bool) error {
return err
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
app := datadog.NewDatadogApp(conf)
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Datadog Incident Management Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Datadog Incident Management Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/datadog/testlib/fake_datadog.go b/integrations/access/datadog/testlib/fake_datadog.go
index 64ef2e35b93b7..5cfe8b539f454 100644
--- a/integrations/access/datadog/testlib/fake_datadog.go
+++ b/integrations/access/datadog/testlib/fake_datadog.go
@@ -32,7 +32,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/datadog"
)
@@ -281,6 +280,6 @@ func (d *FakeDatadog) GetOncallTeams() (map[string][]string, bool) {
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/discord/bot.go b/integrations/access/discord/bot.go
index ca231bdf83a93..576606998b23c 100644
--- a/integrations/access/discord/bot.go
+++ b/integrations/access/discord/bot.go
@@ -94,8 +94,7 @@ func emitStatusUpdate(resp *resty.Response, statusSink common.StatusSink) {
if err := statusSink.Emit(ctx, status); err != nil {
logger.Get(resp.Request.Context()).
- WithError(err).
- Errorf("Error while emitting Discord plugin status: %v", err)
+ ErrorContext(ctx, "Error while emitting Discord plugin status", "error", err)
}
}
diff --git a/integrations/access/discord/cmd/teleport-discord/main.go b/integrations/access/discord/cmd/teleport-discord/main.go
index cd19ce64591b6..f624b407742ba 100644
--- a/integrations/access/discord/cmd/teleport-discord/main.go
+++ b/integrations/access/discord/cmd/teleport-discord/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -65,12 +66,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := discord.LoadDiscordConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -84,14 +86,15 @@ func run(configPath string, debug bool) error {
return trace.Wrap(err)
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
app := discord.NewApp(conf)
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Discord Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Discord Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/discord/testlib/fake_discord.go b/integrations/access/discord/testlib/fake_discord.go
index c5a176446be5b..0a059d8ac81e2 100644
--- a/integrations/access/discord/testlib/fake_discord.go
+++ b/integrations/access/discord/testlib/fake_discord.go
@@ -32,7 +32,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/discord"
)
@@ -188,6 +187,6 @@ func (s *FakeDiscord) CheckMessageUpdateByResponding(ctx context.Context) (disco
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/email/app.go b/integrations/access/email/app.go
index 07bb3b558080e..cae9c33ed5315 100644
--- a/integrations/access/email/app.go
+++ b/integrations/access/email/app.go
@@ -18,6 +18,7 @@ package email
import (
"context"
+ "log/slog"
"slices"
"time"
@@ -32,6 +33,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -90,7 +92,6 @@ func (a *App) run(ctx context.Context) error {
var err error
log := logger.Get(ctx)
- log.Infof("Starting Teleport Access Email Plugin")
if err = a.init(ctx); err != nil {
return trace.Wrap(err)
@@ -137,9 +138,9 @@ func (a *App) run(ctx context.Context) error {
a.mainJob.SetReady(ok)
if ok {
- log.Info("Plugin is ready")
+ log.InfoContext(ctx, "Plugin is ready")
} else {
- log.Error("Plugin is not ready")
+ log.ErrorContext(ctx, "Plugin is not ready")
}
<-watcherJob.Done()
@@ -186,24 +187,24 @@ func (a *App) init(ctx context.Context) error {
},
})
- log.Debug("Starting client connection health check...")
+ log.DebugContext(ctx, "Starting client connection health check")
if err = a.client.CheckHealth(ctx); err != nil {
return trace.Wrap(err, "client connection health check failed")
}
- log.Debug("Client connection health check finished ok")
+ log.DebugContext(ctx, "Client connection health check finished ok")
return nil
}
// checkTeleportVersion checks that Teleport version is not lower than required
func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.apiClient.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion)
}
- log.Error("Unable to get Teleport server version")
+ log.ErrorContext(ctx, "Unable to get Teleport server version")
return pong, trace.Wrap(err)
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
@@ -229,16 +230,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
}
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
var err error
switch {
@@ -249,21 +250,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
 		case req.GetState().IsDenied():
 			err = a.onResolvedRequest(ctx, req)
 		default:
-			log.WithField("event", event).Warn("Unknown request state")
+			log.WarnContext(ctx, "Unknown request state",
+				slog.Group("event",
+					slog.Any("type", logutils.StringerAttr(event.Type)),
+					slog.Group("resource",
+						"kind", event.Resource.GetKind(),
+						"name", event.Resource.GetName(),
+					),
+				),
+			)
return nil
}
if err != nil {
- log.WithError(err).Errorf("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Errorf("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -292,7 +303,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err
if isNew {
recipients := a.getRecipients(ctx, req)
if len(recipients) == 0 {
- log.Warning("No recipients to send")
+ log.WarnContext(ctx, "No recipients to send")
return nil
}
@@ -329,7 +340,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er
case types.RequestState_DENIED:
resolution.Tag = ResolvedDenied
default:
- logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String())
+ logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state))
return replyErr
}
err := trace.Wrap(a.sendResolution(ctx, req.GetName(), resolution))
@@ -359,7 +370,7 @@ func (a *App) getRecipients(ctx context.Context, req types.AccessRequest) []comm
rawRecipients := a.conf.RoleToRecipients.GetRawRecipientsFor(req.GetRoles(), req.GetSuggestedReviewers())
for _, rawRecipient := range rawRecipients {
if !lib.IsEmail(rawRecipient) {
- log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", rawRecipient)
+ log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", rawRecipient)
continue
}
recipientSet.Add(common.Recipient{
@@ -382,7 +393,7 @@ func (a *App) sendNewThreads(ctx context.Context, recipients []common.Recipient,
logSentThreads(ctx, threadsSent, "new threads")
if err != nil {
- logger.Get(ctx).WithError(err).Error("Failed send one or more messages")
+		logger.Get(ctx).ErrorContext(ctx, "Failed to send one or more messages", "error", err)
}
_, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) {
@@ -425,7 +436,7 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to post reply: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to post reply: plugin data is missing")
return nil
}
reviews := reqReviews[oldCount:]
@@ -439,7 +450,11 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData
if err != nil {
errors = append(errors, err)
}
- logger.Get(ctx).Infof("New review for request %v by %v is %v", reqID, review.Author, review.ProposedState.String())
+ logger.Get(ctx).InfoContext(ctx, "New review for request",
+ "request_id", reqID,
+ "author", review.Author,
+ "state", logutils.StringerAttr(review.ProposedState),
+ )
logSentThreads(ctx, threadsSent, "new review")
}
@@ -473,7 +488,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol
return trace.Wrap(err)
}
if !ok {
- log.Debug("Failed to update messages: plugin data is missing")
+ log.DebugContext(ctx, "Failed to update messages: plugin data is missing")
return nil
}
@@ -482,7 +497,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol
threadsSent, err := a.client.SendResolution(ctx, threads, reqID, reqData)
logSentThreads(ctx, threadsSent, "request resolved")
- log.Infof("Marked request as %s and sent emails!", resolution.Tag)
+ log.InfoContext(ctx, "Marked request with resolution and sent emails", "resolution", resolution.Tag)
if err != nil {
return trace.Wrap(err)
@@ -567,10 +582,11 @@ func (a *App) updatePluginData(ctx context.Context, reqID string, data PluginDat
// logSentThreads logs successfully sent emails
func logSentThreads(ctx context.Context, threads []EmailThread, kind string) {
for _, thread := range threads {
- logger.Get(ctx).WithFields(logger.Fields{
- "email": thread.Email,
- "timestamp": thread.Timestamp,
- "message_id": thread.MessageID,
- }).Infof("Successfully sent %v!", kind)
+ logger.Get(ctx).InfoContext(ctx, "Successfully sent",
+ "email", thread.Email,
+ "timestamp", thread.Timestamp,
+ "message_id", thread.MessageID,
+ "kind", kind,
+ )
}
}
diff --git a/integrations/access/email/client.go b/integrations/access/email/client.go
index 6ef1d2f04144e..f687f5deb0009 100644
--- a/integrations/access/email/client.go
+++ b/integrations/access/email/client.go
@@ -61,16 +61,16 @@ func NewClient(ctx context.Context, conf Config, clusterName, webProxyAddr strin
if conf.Mailgun != nil {
mailer = NewMailgunMailer(*conf.Mailgun, conf.StatusSink, conf.Delivery.Sender, clusterName, conf.RoleToRecipients[types.Wildcard])
- logger.Get(ctx).WithField("domain", conf.Mailgun.Domain).Info("Using Mailgun as email transport")
+ logger.Get(ctx).InfoContext(ctx, "Using Mailgun as email transport", "domain", conf.Mailgun.Domain)
}
if conf.SMTP != nil {
mailer = NewSMTPMailer(*conf.SMTP, conf.StatusSink, conf.Delivery.Sender, clusterName)
- logger.Get(ctx).WithFields(logger.Fields{
- "host": conf.SMTP.Host,
- "port": conf.SMTP.Port,
- "username": conf.SMTP.Username,
- }).Info("Using SMTP as email transport")
+ logger.Get(ctx).InfoContext(ctx, "Using SMTP as email transport",
+ "host", conf.SMTP.Host,
+ "port", conf.SMTP.Port,
+ "username", conf.SMTP.Username,
+ )
}
return Client{
diff --git a/integrations/access/email/cmd/teleport-email/main.go b/integrations/access/email/cmd/teleport-email/main.go
index 840c80da76177..ccaec3acbed36 100644
--- a/integrations/access/email/cmd/teleport-email/main.go
+++ b/integrations/access/email/cmd/teleport-email/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -65,12 +66,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := email.LoadConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -84,11 +86,11 @@ func run(configPath string, debug bool) error {
return err
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
if conf.Delivery.Recipients != nil {
- logger.Standard().Warn("The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality")
+ slog.WarnContext(ctx, "The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality")
}
app, err := email.NewApp(*conf)
@@ -98,8 +100,9 @@ func run(configPath string, debug bool) error {
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Email Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Email Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/email/mailers.go b/integrations/access/email/mailers.go
index 60d5b4592449f..5cbd3d98bee02 100644
--- a/integrations/access/email/mailers.go
+++ b/integrations/access/email/mailers.go
@@ -114,7 +114,7 @@ func (m *SMTPMailer) CheckHealth(ctx context.Context) error {
return trace.Wrap(err)
}
if err := client.Close(); err != nil {
- log.Debug("Failed to close client connection after health check")
+ log.DebugContext(ctx, "Failed to close client connection after health check")
}
return nil
}
@@ -191,7 +191,7 @@ func (m *SMTPMailer) emitStatus(ctx context.Context, statusErr error) {
code = http.StatusInternalServerError
}
if err := m.sink.Emit(ctx, common.StatusFromStatusCode(code)); err != nil {
- log.WithError(err).Error("Error while emitting Email plugin status")
+ log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err)
}
}
@@ -252,7 +252,7 @@ func (t *statusSinkTransport) RoundTrip(req *http.Request) (*http.Response, erro
status := common.StatusFromStatusCode(resp.StatusCode)
if err := t.sink.Emit(ctx, status); err != nil {
- log.WithError(err).Error("Error while emitting Email plugin status")
+ log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err)
}
}
return resp, nil
diff --git a/integrations/access/email/testlib/mock_mailgun.go b/integrations/access/email/testlib/mock_mailgun.go
index 58cbbc8ebb098..7895a5cdcaefe 100644
--- a/integrations/access/email/testlib/mock_mailgun.go
+++ b/integrations/access/email/testlib/mock_mailgun.go
@@ -24,7 +24,6 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
)
const (
@@ -58,7 +57,8 @@ func newMockMailgunServer(concurrency int) *mockMailgunServer {
s := httptest.NewUnstartedServer(func(mg *mockMailgunServer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(multipartFormBufSize); err != nil {
- log.Error(err)
+ w.WriteHeader(http.StatusInternalServerError)
+ return
}
id := uuid.New().String()
diff --git a/integrations/access/jira/app.go b/integrations/access/jira/app.go
index 2aab94e887f0d..c8e6c8273ec02 100644
--- a/integrations/access/jira/app.go
+++ b/integrations/access/jira/app.go
@@ -21,6 +21,7 @@ package jira
import (
"context"
"fmt"
+ "log/slog"
"net/url"
"regexp"
"strings"
@@ -40,6 +41,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -125,7 +127,6 @@ func (a *App) run(ctx context.Context) error {
var err error
log := logger.Get(ctx)
- log.Infof("Starting Teleport Jira Plugin")
if err = a.init(ctx); err != nil {
return trace.Wrap(err)
@@ -164,9 +165,9 @@ func (a *App) run(ctx context.Context) error {
ok := (a.webhookSrv == nil || httpOk) && watcherOk
a.mainJob.SetReady(ok)
if ok {
- log.Info("Plugin is ready")
+ log.InfoContext(ctx, "Plugin is ready")
} else {
- log.Error("Plugin is not ready")
+ log.ErrorContext(ctx, "Plugin is not ready")
}
if httpJob != nil {
@@ -205,11 +206,11 @@ func (a *App) init(ctx context.Context) error {
return trace.Wrap(err)
}
- log.Debug("Starting Jira API health check...")
+ log.DebugContext(ctx, "Starting Jira API health check")
if err = a.jira.HealthCheck(ctx); err != nil {
return trace.Wrap(err, "api health check failed")
}
- log.Debug("Jira API health check finished ok")
+ log.DebugContext(ctx, "Jira API health check finished ok")
if !a.conf.DisableWebhook {
webhookSrv, err := NewWebhookServer(a.conf.HTTP, a.onJiraWebhook)
@@ -227,13 +228,13 @@ func (a *App) init(ctx context.Context) error {
func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.teleport.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion)
}
- log.Error("Unable to get Teleport server version")
+ log.ErrorContext(ctx, "Unable to get Teleport server version")
return pong, trace.Wrap(err)
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
@@ -246,17 +247,17 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
}
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
- log.Debug("Processing watcher event")
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
+ log.DebugContext(ctx, "Processing watcher event")
var err error
switch {
@@ -265,21 +266,29 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
case req.GetState().IsResolved():
err = a.onResolvedRequest(ctx, req)
default:
- log.WithField("event", event).Warn("Unknown request state")
+ log.WarnContext(ctx, "Unknown request state",
+ slog.Group("event",
+ slog.Any("type", logutils.StringerAttr(event.Type)),
+ slog.Group("resource",
+ "kind", event.Resource.GetKind(),
+ "name", event.Resource.GetName(),
+ ),
+ ),
+ )
return nil
}
if err != nil {
- log.WithError(err).Error("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Errorf("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -299,10 +308,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error {
return nil
}
- ctx, log := logger.WithFields(ctx, logger.Fields{
- "jira_issue_id": webhook.Issue.ID,
- })
- log.Debugf("Processing incoming webhook event %q with type %q", webhookEvent, issueEventTypeName)
+ ctx, log := logger.With(ctx, "jira_issue_id", webhook.Issue.ID)
+ log.DebugContext(ctx, "Processing incoming webhook event",
+ "event", webhookEvent,
+ "event_type", issueEventTypeName,
+ )
if webhook.Issue == nil {
return trace.Errorf("got webhook without issue info")
@@ -333,20 +343,20 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error {
if statusName == "" {
return trace.Errorf("getting Jira issue status: %w", err)
}
- log.Warnf("Using most recent successful getIssue response: %v", err)
+ log.WarnContext(ctx, "Using most recent successful getIssue response", "error", err)
}
- ctx, log = logger.WithFields(ctx, logger.Fields{
- "jira_issue_id": issue.ID,
- "jira_issue_key": issue.Key,
- })
+ ctx, log = logger.With(ctx,
+ "jira_issue_id", issue.ID,
+ "jira_issue_key", issue.Key,
+ )
switch {
case statusName == "pending":
- log.Debug("Issue has pending status, ignoring it")
+ log.DebugContext(ctx, "Issue has pending status, ignoring it")
return nil
case statusName == "expired":
- log.Debug("Issue has expired status, ignoring it")
+ log.DebugContext(ctx, "Issue has expired status, ignoring it")
return nil
case statusName != "approved" && statusName != "denied":
return trace.BadParameter("unknown Jira status %s", statusName)
@@ -357,11 +367,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error {
return trace.Wrap(err)
}
if reqID == "" {
- log.Debugf("Missing %q issue property", RequestIDPropertyKey)
+ log.DebugContext(ctx, "Missing teleportAccessRequestId issue property")
return nil
}
- ctx, log = logger.WithField(ctx, "request_id", reqID)
+ ctx, log = logger.With(ctx, "request_id", reqID)
reqs, err := a.teleport.GetAccessRequests(ctx, types.AccessRequestFilter{ID: reqID})
if err != nil {
@@ -382,8 +392,9 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error {
return trace.Errorf("plugin data is blank")
}
if pluginData.IssueID != issue.ID {
- log.WithField("plugin_data_issue_id", pluginData.IssueID).
- Debug("plugin_data.issue_id does not match issue.id")
+ log.DebugContext(ctx, "plugin_data.issue_id does not match issue.id",
+ "plugin_data_issue_id", pluginData.IssueID,
+ )
return trace.Errorf("issue_id from request's plugin_data does not match")
}
@@ -406,17 +417,17 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error {
author, reason, err := a.loadResolutionInfo(ctx, issue, statusName)
if err != nil {
- log.WithError(err).Error("Failed to load resolution info from the issue history")
+ log.ErrorContext(ctx, "Failed to load resolution info from the issue history", "error", err)
}
resolution.Reason = reason
- ctx, _ = logger.WithFields(ctx, logger.Fields{
- "jira_user_email": author.EmailAddress,
- "jira_user_name": author.DisplayName,
- "request_user": req.GetUser(),
- "request_roles": req.GetRoles(),
- "reason": reason,
- })
+ ctx, _ = logger.With(ctx,
+ "jira_user_email", author.EmailAddress,
+ "jira_user_name", author.DisplayName,
+ "request_user", req.GetUser(),
+ "request_roles", req.GetRoles(),
+ "reason", reason,
+ )
if err := a.resolveRequest(ctx, reqID, author.EmailAddress, resolution); err != nil {
return trace.Wrap(err)
}
@@ -498,11 +509,11 @@ func (a *App) createIssue(ctx context.Context, reqID string, reqData RequestData
return trace.Wrap(err)
}
- ctx, log := logger.WithFields(ctx, logger.Fields{
- "jira_issue_id": data.IssueID,
- "jira_issue_key": data.IssueKey,
- })
- log.Info("Jira Issue created")
+ ctx, log := logger.With(ctx,
+ "jira_issue_id", data.IssueID,
+ "jira_issue_key", data.IssueKey,
+ )
+ log.InfoContext(ctx, "Jira Issue created")
// Save jira issue info in plugin data.
_, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) {
@@ -551,11 +562,11 @@ func (a *App) addReviewComments(ctx context.Context, reqID string, reqReviews []
}
if !ok {
if issueID == "" {
- logger.Get(ctx).Debug("Failed to add the comment: plugin data is blank")
+ logger.Get(ctx).DebugContext(ctx, "Failed to add the comment: plugin data is blank")
}
return nil
}
- ctx, _ = logger.WithField(ctx, "jira_issue_id", issueID)
+ ctx, _ = logger.With(ctx, "jira_issue_id", issueID)
slice := reqReviews[oldCount:]
if len(slice) == 0 {
@@ -621,7 +632,7 @@ func (a *App) resolveRequest(ctx context.Context, reqID string, userEmail string
return trace.Wrap(err)
}
- logger.Get(ctx).Infof("Jira user %s the request", resolution.Tag)
+ logger.Get(ctx).InfoContext(ctx, "Jira user processed the request", "resolution", resolution.Tag)
return nil
}
@@ -658,18 +669,18 @@ func (a *App) resolveIssue(ctx context.Context, reqID string, resolution Resolut
}
if !ok {
if issueID == "" {
- logger.Get(ctx).Debug("Failed to resolve the issue: plugin data is blank")
+ logger.Get(ctx).DebugContext(ctx, "Failed to resolve the issue: plugin data is blank")
}
// Either plugin data is missing or issue is already resolved by us, just quit.
return nil
}
- ctx, log := logger.WithField(ctx, "jira_issue_id", issueID)
+ ctx, log := logger.With(ctx, "jira_issue_id", issueID)
if err := a.jira.ResolveIssue(ctx, issueID, resolution); err != nil {
return trace.Wrap(err)
}
- log.Info("Successfully resolved the issue")
+ log.InfoContext(ctx, "Successfully resolved the issue")
return nil
}
diff --git a/integrations/access/jira/client.go b/integrations/access/jira/client.go
index 2877966af663b..a23381e4d2666 100644
--- a/integrations/access/jira/client.go
+++ b/integrations/access/jira/client.go
@@ -125,7 +125,7 @@ func NewJiraClient(conf JiraConfig, clusterName, teleportProxyAddr string, statu
defer cancel()
if err := statusSink.Emit(ctx, status); err != nil {
- log.WithError(err).Errorf("Error while emitting Jira plugin status: %v", err)
+ log.ErrorContext(ctx, "Error while emitting Jira plugin status", "error", err)
}
}
@@ -199,7 +199,7 @@ func (j *Jira) HealthCheck(ctx context.Context) error {
}
}
- log.Debug("Checking out Jira project...")
+ log.DebugContext(ctx, "Checking out Jira project")
var project Project
_, err = j.client.NewRequest().
SetContext(ctx).
@@ -209,9 +209,12 @@ func (j *Jira) HealthCheck(ctx context.Context) error {
if err != nil {
return trace.Wrap(err)
}
- log.Debugf("Found project %q named %q", project.Key, project.Name)
+ log.DebugContext(ctx, "Found Jira project",
+ "project", project.Key,
+ "project_name", project.Name,
+ )
- log.Debug("Checking out Jira project permissions...")
+ log.DebugContext(ctx, "Checking out Jira project permissions")
queryOptions, err := query.Values(GetMyPermissionsQueryOptions{
ProjectKey: j.project,
Permissions: jiraRequiredPermissions,
@@ -433,7 +436,7 @@ func (j *Jira) ResolveIssue(ctx context.Context, issueID string, resolution Reso
if err2 := trace.Wrap(j.TransitionIssue(ctx, issue.ID, transition.ID)); err2 != nil {
return trace.NewAggregate(err1, err2)
}
- logger.Get(ctx).Debugf("Successfully moved the issue to the status %q", toStatus)
+ logger.Get(ctx).DebugContext(ctx, "Successfully moved the issue to the target status", "target_status", toStatus)
return trace.Wrap(err1)
}
@@ -457,7 +460,7 @@ func (j *Jira) AddResolutionComment(ctx context.Context, id string, resolution R
SetBody(CommentInput{Body: builder.String()}).
Post("rest/api/2/issue/{issueID}/comment")
if err == nil {
- logger.Get(ctx).Debug("Successfully added a resolution comment to the issue")
+ logger.Get(ctx).DebugContext(ctx, "Successfully added a resolution comment to the issue")
}
return trace.Wrap(err)
}
diff --git a/integrations/access/jira/cmd/teleport-jira/main.go b/integrations/access/jira/cmd/teleport-jira/main.go
index b2c2bb0672d06..851de27473296 100644
--- a/integrations/access/jira/cmd/teleport-jira/main.go
+++ b/integrations/access/jira/cmd/teleport-jira/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -72,12 +73,13 @@ func main() {
if err := run(*path, *insecure, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, insecure bool, debug bool) error {
+ ctx := context.Background()
conf, err := jira.LoadConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -91,7 +93,7 @@ func run(configPath string, insecure bool, debug bool) error {
return err
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
conf.HTTP.Insecure = insecure
@@ -102,8 +104,9 @@ func run(configPath string, insecure bool, debug bool) error {
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Jira Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Jira Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/jira/testlib/fake_jira.go b/integrations/access/jira/testlib/fake_jira.go
index 1da8c432ec3a9..9696500620aba 100644
--- a/integrations/access/jira/testlib/fake_jira.go
+++ b/integrations/access/jira/testlib/fake_jira.go
@@ -30,7 +30,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/jira"
)
@@ -304,6 +303,6 @@ func (s *FakeJira) CheckIssueTransition(ctx context.Context) (jira.Issue, error)
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/jira/testlib/suite.go b/integrations/access/jira/testlib/suite.go
index 38341d589fa5d..c2a3d421f442c 100644
--- a/integrations/access/jira/testlib/suite.go
+++ b/integrations/access/jira/testlib/suite.go
@@ -721,7 +721,7 @@ func (s *JiraSuiteOSS) TestRace() {
defer cancel()
var lastErr error
for {
- logger.Get(ctx).Infof("Trying to approve issue %q", issue.Key)
+ logger.Get(ctx).InfoContext(ctx, "Trying to approve issue", "issue_key", issue.Key)
resp, err := s.postWebhook(ctx, s.webhookURL.String(), issue.ID, "Approved")
if err != nil {
if lib.IsDeadline(err) {
diff --git a/integrations/access/jira/webhook_server.go b/integrations/access/jira/webhook_server.go
index b83e449b992c8..e9e409959b40a 100644
--- a/integrations/access/jira/webhook_server.go
+++ b/integrations/access/jira/webhook_server.go
@@ -105,29 +105,32 @@ func (s *WebhookServer) processWebhook(rw http.ResponseWriter, r *http.Request,
defer cancel()
httpRequestID := fmt.Sprintf("%v-%v", time.Now().Unix(), atomic.AddUint64(&s.counter, 1))
- ctx, log := logger.WithField(ctx, "jira_http_id", httpRequestID)
+ ctx, log := logger.With(ctx, "jira_http_id", httpRequestID)
var webhook Webhook
body, err := io.ReadAll(io.LimitReader(r.Body, jiraWebhookPayloadLimit+1))
if err != nil {
- log.WithError(err).Error("Failed to read webhook payload")
+ log.ErrorContext(ctx, "Failed to read webhook payload", "error", err)
http.Error(rw, "", http.StatusInternalServerError)
return
}
if len(body) > jiraWebhookPayloadLimit {
- log.Error("Received a webhook larger than %d bytes", jiraWebhookPayloadLimit)
+ log.ErrorContext(ctx, "Received a webhook with a payload that exceeded the limit",
+ "payload_size", len(body),
+ "payload_size_limit", jiraWebhookPayloadLimit,
+ )
http.Error(rw, "", http.StatusRequestEntityTooLarge)
+ return
}
if err = json.Unmarshal(body, &webhook); err != nil {
- log.WithError(err).Error("Failed to parse webhook payload")
+ log.ErrorContext(ctx, "Failed to parse webhook payload", "error", err)
http.Error(rw, "", http.StatusBadRequest)
return
}
if err = s.onWebhook(ctx, webhook); err != nil {
- log.WithError(err).Error("Failed to process webhook")
- log.Debugf("%v", trace.DebugReport(err))
+ log.ErrorContext(ctx, "Failed to process webhook", "error", err)
var code int
switch {
case lib.IsCanceled(err) || lib.IsDeadline(err):
diff --git a/integrations/access/mattermost/bot.go b/integrations/access/mattermost/bot.go
index c7de9d0aaae44..edf0a7e73264d 100644
--- a/integrations/access/mattermost/bot.go
+++ b/integrations/access/mattermost/bot.go
@@ -150,7 +150,7 @@ func NewBot(conf Config, clusterName, webProxyAddr string) (Bot, error) {
ctx, cancel := context.WithTimeout(context.Background(), mmStatusEmitTimeout)
defer cancel()
if err := sink.Emit(ctx, status); err != nil {
- log.Errorf("Error while emitting plugin status: %v", err)
+ log.ErrorContext(ctx, "Error while emitting plugin status", "error", err)
}
}()
@@ -463,14 +463,14 @@ func (b Bot) buildPostText(reqID string, reqData pd.AccessRequestData) (string,
}
func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) string {
- log := logger.Get(ctx).WithField("mm_user_email", userEmail)
+ log := logger.Get(ctx).With("mm_user_email", userEmail)
channel, err := b.LookupDirectChannel(ctx, userEmail)
if err != nil {
var errResult *ErrorResult
if errors.As(trace.Unwrap(err), &errResult) {
- log.Warningf("Failed to lookup direct channel info: %q", errResult.Message)
+ log.WarnContext(ctx, "Failed to lookup direct channel info", "error", errResult.Message)
} else {
- log.WithError(err).Error("Failed to lookup direct channel info")
+ log.ErrorContext(ctx, "Failed to lookup direct channel info", "error", err)
}
return ""
}
@@ -478,17 +478,17 @@ func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) strin
}
func (b Bot) tryLookupChannel(ctx context.Context, team, name string) string {
- log := logger.Get(ctx).WithFields(logger.Fields{
- "mm_team": team,
- "mm_channel": name,
- })
+ log := logger.Get(ctx).With(
+ "mm_team", team,
+ "mm_channel", name,
+ )
channel, err := b.LookupChannel(ctx, team, name)
if err != nil {
var errResult *ErrorResult
if errors.As(trace.Unwrap(err), &errResult) {
- log.Warningf("Failed to lookup channel info: %q", errResult.Message)
+ log.WarnContext(ctx, "Failed to lookup channel info", "error", errResult.Message)
} else {
- log.WithError(err).Error("Failed to lookup channel info")
+ log.ErrorContext(ctx, "Failed to lookup channel info", "error", err)
}
return ""
}
diff --git a/integrations/access/mattermost/cmd/teleport-mattermost/main.go b/integrations/access/mattermost/cmd/teleport-mattermost/main.go
index 7c4777b26655b..0c67abb62ef86 100644
--- a/integrations/access/mattermost/cmd/teleport-mattermost/main.go
+++ b/integrations/access/mattermost/cmd/teleport-mattermost/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -65,12 +66,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := mattermost.LoadConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -84,14 +86,15 @@ func run(configPath string, debug bool) error {
return err
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
app := mattermost.NewMattermostApp(conf)
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Mattermost Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Mattermost Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/mattermost/testlib/fake_mattermost.go b/integrations/access/mattermost/testlib/fake_mattermost.go
index 10cc048e743bd..b2c28287c6153 100644
--- a/integrations/access/mattermost/testlib/fake_mattermost.go
+++ b/integrations/access/mattermost/testlib/fake_mattermost.go
@@ -31,7 +31,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/mattermost"
)
@@ -387,6 +386,6 @@ func (s *FakeMattermost) CheckPostUpdate(ctx context.Context) (mattermost.Post,
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/msteams/app.go b/integrations/access/msteams/app.go
index 306be091ca8b0..b18c96ba3f4a3 100644
--- a/integrations/access/msteams/app.go
+++ b/integrations/access/msteams/app.go
@@ -62,14 +62,9 @@ type App struct {
// NewApp initializes a new teleport-msteams app and returns it.
func NewApp(conf Config) (*App, error) {
- log, err := conf.Log.NewSLogLogger()
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
app := &App{
conf: conf,
- log: log.With("plugin", pluginName),
+ log: slog.With("plugin", pluginName),
}
app.mainJob = lib.NewServiceJob(app.run)
diff --git a/integrations/access/msteams/bot.go b/integrations/access/msteams/bot.go
index c0598c1f4d24f..4292f856dba90 100644
--- a/integrations/access/msteams/bot.go
+++ b/integrations/access/msteams/bot.go
@@ -30,7 +30,6 @@ import (
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/msteams/msapi"
"github.com/gravitational/teleport/integrations/lib"
- "github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/plugindata"
)
@@ -469,7 +468,7 @@ func (b *Bot) CheckHealth(ctx context.Context) error {
Code: status,
ErrorMessage: message,
}); err != nil {
- logger.Get(ctx).Errorf("Error while emitting ms teams plugin status: %v", err)
+ b.log.ErrorContext(ctx, "Error while emitting ms teams plugin status", "error", err)
}
}
return trace.Wrap(err)
diff --git a/integrations/access/msteams/cmd/teleport-msteams/main.go b/integrations/access/msteams/cmd/teleport-msteams/main.go
index 970df1ac98db4..75e66a46b7cf7 100644
--- a/integrations/access/msteams/cmd/teleport-msteams/main.go
+++ b/integrations/access/msteams/cmd/teleport-msteams/main.go
@@ -16,6 +16,7 @@ package main
import (
"context"
+ "log/slog"
"os"
"time"
@@ -99,7 +100,7 @@ func main() {
if err := run(*startConfigPath, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
diff --git a/integrations/access/msteams/testlib/fake_msteams.go b/integrations/access/msteams/testlib/fake_msteams.go
index ceb1a3edc2d41..f3e4d4c5550c2 100644
--- a/integrations/access/msteams/testlib/fake_msteams.go
+++ b/integrations/access/msteams/testlib/fake_msteams.go
@@ -30,7 +30,6 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/msteams/msapi"
)
@@ -326,6 +325,6 @@ func (s *FakeTeams) CheckMessageUpdate(ctx context.Context) (Msg, error) {
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/msteams/uninstall.go b/integrations/access/msteams/uninstall.go
index e60a9ce0c8ddd..22aa9e6961ab1 100644
--- a/integrations/access/msteams/uninstall.go
+++ b/integrations/access/msteams/uninstall.go
@@ -18,7 +18,8 @@ import (
"context"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
+
+ "github.com/gravitational/teleport/integrations/lib/logger"
)
func Uninstall(ctx context.Context, configPath string) error {
@@ -26,11 +27,13 @@ func Uninstall(ctx context.Context, configPath string) error {
if err != nil {
return trace.Wrap(err)
}
- err = checkApp(ctx, b)
- if err != nil {
+
+ if err := checkApp(ctx, b); err != nil {
return trace.Wrap(err)
}
+ log := logger.Get(ctx)
+
var errs []error
for _, recipient := range c.Recipients.GetAllRawRecipients() {
_, isChannel := b.checkChannelURL(recipient)
@@ -38,11 +41,11 @@ func Uninstall(ctx context.Context, configPath string) error {
errs = append(errs, b.UninstallAppForUser(ctx, recipient))
}
}
- err = trace.NewAggregate(errs...)
- if err != nil {
- log.Errorln("The following error(s) happened when uninstalling the Teams App:")
+
+ if err := trace.NewAggregate(errs...); err != nil {
+ log.ErrorContext(ctx, "Encountered error(s) when uninstalling the Teams App", "error", err)
return err
}
- log.Info("Successfully uninstalled app for all recipients")
+ log.InfoContext(ctx, "Successfully uninstalled app for all recipients")
return nil
}
diff --git a/integrations/access/msteams/validate.go b/integrations/access/msteams/validate.go
index 61d9d25f635e8..7969d7edebe0d 100644
--- a/integrations/access/msteams/validate.go
+++ b/integrations/access/msteams/validate.go
@@ -17,6 +17,7 @@ package msteams
import (
"context"
"fmt"
+ "log/slog"
"time"
cards "github.com/DanielTitkov/go-adaptive-cards"
@@ -142,11 +143,7 @@ func loadConfig(configPath string) (*Bot, *Config, error) {
fmt.Printf(" - Checking application %v status...\n", c.MSAPI.TeamsAppID)
- log, err := c.Log.NewSLogLogger()
- if err != nil {
- return nil, nil, trace.Wrap(err)
- }
- b, err := NewBot(c, "local", "", log)
+ b, err := NewBot(c, "local", "", slog.Default())
if err != nil {
return nil, nil, trace.Wrap(err)
}
diff --git a/integrations/access/opsgenie/app.go b/integrations/access/opsgenie/app.go
index 132389ad5b5a3..60950f31fa4b1 100644
--- a/integrations/access/opsgenie/app.go
+++ b/integrations/access/opsgenie/app.go
@@ -22,6 +22,7 @@ import (
"context"
"errors"
"fmt"
+ "log/slog"
"strings"
"time"
@@ -39,6 +40,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -115,7 +117,7 @@ func (a *App) run(ctx context.Context) error {
var err error
log := logger.Get(ctx)
- log.Infof("Starting Teleport Access Opsgenie Plugin")
+ log.InfoContext(ctx, "Starting Teleport Access Opsgenie Plugin")
if err = a.init(ctx); err != nil {
return trace.Wrap(err)
@@ -147,9 +149,9 @@ func (a *App) run(ctx context.Context) error {
a.mainJob.SetReady(ok)
if ok {
- log.Info("Plugin is ready")
+ log.InfoContext(ctx, "Plugin is ready")
} else {
- log.Error("Plugin is not ready")
+ log.ErrorContext(ctx, "Plugin is not ready")
}
<-watcherJob.Done()
@@ -177,24 +179,24 @@ func (a *App) init(ctx context.Context) error {
}
log := logger.Get(ctx)
- log.Debug("Starting API health check...")
+ log.DebugContext(ctx, "Starting API health check")
if err = a.opsgenie.CheckHealth(ctx); err != nil {
return trace.Wrap(err, "API health check failed")
}
- log.Debug("API health check finished ok")
+ log.DebugContext(ctx, "API health check finished ok")
return nil
}
func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.teleport.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion)
}
- log.Error("Unable to get Teleport server version")
+ log.ErrorContext(ctx, "Unable to get Teleport server version")
return pong, trace.Wrap(err)
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
@@ -219,16 +221,16 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error {
}
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
var err error
switch {
@@ -237,21 +239,29 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error {
case req.GetState().IsResolved():
err = a.onResolvedRequest(ctx, req)
default:
- log.WithField("event", event).Warn("Unknown request state")
+ log.WarnContext(ctx, "Unknown request state",
+ slog.Group("event",
+ slog.Any("type", logutils.StringerAttr(event.Type)),
+ slog.Group("resource",
+ "kind", event.Resource.GetKind(),
+ "name", event.Resource.GetName(),
+ ),
+ ),
+ )
return nil
}
if err != nil {
- log.WithError(err).Error("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Error("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -310,13 +320,13 @@ func (a *App) getNotifySchedulesAndTeams(ctx context.Context, req types.AccessRe
scheduleAnnotationKey := types.TeleportNamespace + types.ReqAnnotationNotifySchedulesLabel
schedules, err = common.GetNamesFromAnnotations(req, scheduleAnnotationKey)
if err != nil {
- log.Debugf("No schedules to notify in %s", scheduleAnnotationKey)
+ log.DebugContext(ctx, "No schedules to notify", "schedule", scheduleAnnotationKey)
}
teamAnnotationKey := types.TeleportNamespace + types.ReqAnnotationTeamsLabel
teams, err = common.GetNamesFromAnnotations(req, teamAnnotationKey)
if err != nil {
- log.Debugf("No teams to notify in %s", teamAnnotationKey)
+ log.DebugContext(ctx, "No teams to notify", "teams", teamAnnotationKey)
}
if len(schedules) == 0 && len(teams) == 0 {
@@ -336,7 +346,7 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo
recipientSchedules, recipientTeams, err := a.getMessageRecipients(ctx, req)
if err != nil {
- log.Debugf("Skipping the notification: %s", err)
+ log.DebugContext(ctx, "Skipping notification", "error", err)
return false, trace.Wrap(errMissingAnnotation)
}
@@ -434,8 +444,8 @@ func (a *App) createAlert(ctx context.Context, reqID string, reqData RequestData
if err != nil {
return trace.Wrap(err)
}
- ctx, log := logger.WithField(ctx, "opsgenie_alert_id", data.AlertID)
- log.Info("Successfully created Opsgenie alert")
+ ctx, log := logger.With(ctx, "opsgenie_alert_id", data.AlertID)
+ log.InfoContext(ctx, "Successfully created Opsgenie alert")
// Save opsgenie alert info in plugin data.
_, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) {
@@ -479,10 +489,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to post the note: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing")
return nil
}
- ctx, _ = logger.WithField(ctx, "opsgenie_alert_id", data.AlertID)
+ ctx, _ = logger.With(ctx, "opsgenie_alert_id", data.AlertID)
slice := reqReviews[oldCount:]
if len(slice) == 0 {
@@ -504,7 +514,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
serviceNames, err := a.getOnCallServiceNames(req)
if err != nil {
- logger.Get(ctx).Debugf("Skipping the approval: %s", err)
+ logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err)
return nil
}
@@ -537,14 +547,14 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
},
}); err != nil {
if strings.HasSuffix(err.Error(), "has already reviewed this request") {
- log.Debug("Already reviewed the request")
+ log.DebugContext(ctx, "Already reviewed the request")
return nil
}
return trace.Wrap(err, "submitting access request")
}
}
- log.Info("Successfully submitted a request approval")
+ log.InfoContext(ctx, "Successfully submitted a request approval")
return nil
}
@@ -576,15 +586,15 @@ func (a *App) resolveAlert(ctx context.Context, reqID string, resolution Resolut
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to resolve the alert: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to resolve the alert: plugin data is missing")
return nil
}
- ctx, log := logger.WithField(ctx, "opsgenie_alert_id", alertID)
+ ctx, log := logger.With(ctx, "opsgenie_alert_id", alertID)
if err := a.opsgenie.ResolveAlert(ctx, alertID, resolution); err != nil {
return trace.Wrap(err)
}
- log.Info("Successfully resolved the alert")
+ log.InfoContext(ctx, "Successfully resolved the alert")
return nil
}
diff --git a/integrations/access/opsgenie/client.go b/integrations/access/opsgenie/client.go
index 2619c6ed6f7a9..2c8cdaec09a33 100644
--- a/integrations/access/opsgenie/client.go
+++ b/integrations/access/opsgenie/client.go
@@ -185,10 +185,10 @@ func (og Client) tryGetAlertRequestResult(ctx context.Context, reqID string) (Ge
for {
alertRequestResult, err := og.getAlertRequestResult(ctx, reqID)
if err == nil {
- logger.Get(ctx).Debugf("Got alert request result: %+v", alertRequestResult)
+ logger.Get(ctx).DebugContext(ctx, "Got alert request result", "alert_id", alertRequestResult.Data.AlertID)
return alertRequestResult, nil
}
- logger.Get(ctx).Debug("Failed to get alert request result:", err)
+ logger.Get(ctx).DebugContext(ctx, "Failed to get alert request result", "error", err)
if err := backoff.Do(ctx); err != nil {
return GetAlertRequestResult{}, trace.Wrap(err)
}
@@ -344,8 +344,10 @@ func (og Client) CheckHealth(ctx context.Context) error {
code = types.PluginStatusCode_OTHER_ERROR
}
if err := og.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil {
- logger.Get(resp.Request.Context()).WithError(err).
- WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err)
+ logger.Get(resp.Request.Context()).ErrorContext(ctx, "Error while emitting opsgenie plugin status",
+ "error", err,
+ "code", resp.StatusCode(),
+ )
}
}
diff --git a/integrations/access/opsgenie/testlib/fake_opsgenie.go b/integrations/access/opsgenie/testlib/fake_opsgenie.go
index 9b5e6252119d1..1c124e19a75fc 100644
--- a/integrations/access/opsgenie/testlib/fake_opsgenie.go
+++ b/integrations/access/opsgenie/testlib/fake_opsgenie.go
@@ -32,7 +32,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/opsgenie"
@@ -314,7 +313,7 @@ func (s *FakeOpsgenie) GetSchedule(scheduleName string) ([]opsgenie.Responder, b
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/pagerduty/app.go b/integrations/access/pagerduty/app.go
index 5eadcc5147cd0..2351c5d2d5f02 100644
--- a/integrations/access/pagerduty/app.go
+++ b/integrations/access/pagerduty/app.go
@@ -22,6 +22,7 @@ import (
"context"
"errors"
"fmt"
+ "log/slog"
"strings"
"time"
@@ -38,6 +39,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -106,7 +108,7 @@ func (a *App) run(ctx context.Context) error {
var err error
log := logger.Get(ctx)
- log.Infof("Starting Teleport Access PagerDuty Plugin")
+ log.InfoContext(ctx, "Starting Teleport Access PagerDuty Plugin")
if err = a.init(ctx); err != nil {
return trace.Wrap(err)
@@ -146,9 +147,9 @@ func (a *App) run(ctx context.Context) error {
a.mainJob.SetReady(ok)
if ok {
- log.Info("Plugin is ready")
+ log.InfoContext(ctx, "Plugin is ready")
} else {
- log.Error("Plugin is not ready")
+ log.ErrorContext(ctx, "Plugin is not ready")
}
<-watcherJob.Done()
@@ -202,25 +203,25 @@ func (a *App) init(ctx context.Context) error {
return trace.Wrap(err)
}
- log.Debug("Starting PagerDuty API health check...")
+ log.DebugContext(ctx, "Starting PagerDuty API health check")
if err = a.pagerduty.HealthCheck(ctx); err != nil {
return trace.Wrap(err, "api health check failed. check your credentials and service_id settings")
}
- log.Debug("PagerDuty API health check finished ok")
+ log.DebugContext(ctx, "PagerDuty API health check finished ok")
return nil
}
func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.teleport.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion)
}
- log.Error("Unable to get Teleport server version")
+ log.ErrorContext(ctx, "Unable to get Teleport server version")
return pong, trace.Wrap(err)
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
@@ -245,16 +246,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
}
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
var err error
switch {
@@ -263,21 +264,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
case req.GetState().IsResolved():
err = a.onResolvedRequest(ctx, req)
default:
- log.WithField("event", event).Warn("Unknown request state")
+ log.WarnContext(ctx, "Unknown request state",
+ slog.Group("event",
+ slog.Any("type", logutils.StringerAttr(event.Type)),
+ slog.Group("resource",
+ "kind", event.Resource.GetKind(),
+ "name", event.Resource.GetName(),
+ ),
+ ),
+ )
return nil
}
if err != nil {
- log.WithError(err).Error("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Error("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -288,7 +297,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error {
if len(req.GetSystemAnnotations()) == 0 {
- logger.Get(ctx).Debug("Cannot proceed further. Request is missing any annotations")
+ logger.Get(ctx).DebugContext(ctx, "Cannot proceed further - request is missing any annotations")
return nil
}
@@ -370,11 +379,11 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo
serviceName, err := a.getNotifyServiceName(ctx, req)
if err != nil {
- log.Debugf("Skipping the notification: %s", err)
+ log.DebugContext(ctx, "Skipping the notification", "error", err)
return false, trace.Wrap(errSkip)
}
- ctx, _ = logger.WithField(ctx, "pd_service_name", serviceName)
+ ctx, _ = logger.With(ctx, "pd_service_name", serviceName)
service, err := a.pagerduty.FindServiceByName(ctx, serviceName)
if err != nil {
return false, trace.Wrap(err, "finding pagerduty service %s", serviceName)
@@ -420,8 +429,8 @@ func (a *App) createIncident(ctx context.Context, serviceID, reqID string, reqDa
if err != nil {
return trace.Wrap(err)
}
- ctx, log := logger.WithField(ctx, "pd_incident_id", data.IncidentID)
- log.Info("Successfully created PagerDuty incident")
+ ctx, log := logger.With(ctx, "pd_incident_id", data.IncidentID)
+ log.InfoContext(ctx, "Successfully created PagerDuty incident")
// Save pagerduty incident info in plugin data.
_, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) {
@@ -465,10 +474,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to post the note: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing")
return nil
}
- ctx, _ = logger.WithField(ctx, "pd_incident_id", data.IncidentID)
+ ctx, _ = logger.With(ctx, "pd_incident_id", data.IncidentID)
slice := reqReviews[oldCount:]
if len(slice) == 0 {
@@ -490,36 +499,40 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
serviceNames, err := a.getOnCallServiceNames(req)
if err != nil {
- logger.Get(ctx).Debugf("Skipping the approval: %s", err)
+ logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err)
return nil
}
userName := req.GetUser()
if !lib.IsEmail(userName) {
- logger.Get(ctx).Warningf("Skipping the approval: %q does not look like a valid email", userName)
+ logger.Get(ctx).WarnContext(ctx, "Skipping approval, found invalid email", "pd_user_email", userName)
return nil
}
user, err := a.pagerduty.FindUserByEmail(ctx, userName)
if err != nil {
if trace.IsNotFound(err) {
- log.WithError(err).WithField("pd_user_email", userName).Debug("Skipping the approval: email is not found")
+ log.DebugContext(ctx, "Skipping approval, email is not found",
+ "error", err,
+ "pd_user_email", userName)
return nil
}
return trace.Wrap(err)
}
- ctx, log = logger.WithFields(ctx, logger.Fields{
- "pd_user_email": user.Email,
- "pd_user_name": user.Name,
- })
+ ctx, log = logger.With(ctx,
+ "pd_user_email", user.Email,
+ "pd_user_name", user.Name,
+ )
services, err := a.pagerduty.FindServicesByNames(ctx, serviceNames)
if err != nil {
return trace.Wrap(err)
}
if len(services) == 0 {
- log.WithField("pd_service_names", serviceNames).Warning("Failed to find any service")
+ log.WarnContext(ctx, "Failed to find any service",
+ "pd_service_names", serviceNames,
+ )
return nil
}
@@ -536,7 +549,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
return trace.Wrap(err)
}
if len(escalationPolicyIDs) == 0 {
- log.Debug("Skipping the approval: user is not on call")
+ log.DebugContext(ctx, "Skipping the approval: user is not on call")
return nil
}
@@ -561,13 +574,13 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
},
}); err != nil {
if strings.HasSuffix(err.Error(), "has already reviewed this request") {
- log.Debug("Already reviewed the request")
+ log.DebugContext(ctx, "Already reviewed the request")
return nil
}
return trace.Wrap(err, "submitting access request")
}
- log.Info("Successfully submitted a request approval")
+ log.InfoContext(ctx, "Successfully submitted a request approval")
return nil
}
@@ -599,15 +612,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing")
return nil
}
- ctx, log := logger.WithField(ctx, "pd_incident_id", incidentID)
+ ctx, log := logger.With(ctx, "pd_incident_id", incidentID)
if err := a.pagerduty.ResolveIncident(ctx, incidentID, resolution); err != nil {
return trace.Wrap(err)
}
- log.Info("Successfully resolved the incident")
+ log.InfoContext(ctx, "Successfully resolved the incident")
return nil
}
diff --git a/integrations/access/pagerduty/client.go b/integrations/access/pagerduty/client.go
index 51adfb38f5aed..fd42876a154ca 100644
--- a/integrations/access/pagerduty/client.go
+++ b/integrations/access/pagerduty/client.go
@@ -122,7 +122,7 @@ func onAfterPagerDutyResponse(sink common.StatusSink) resty.ResponseMiddleware {
defer cancel()
if err := sink.Emit(ctx, status); err != nil {
- log.WithError(err).Errorf("Error while emitting PagerDuty plugin status: %v", err)
+ log.ErrorContext(ctx, "Error while emitting PagerDuty plugin status", "error", err)
}
if resp.IsError() {
@@ -288,7 +288,7 @@ func (p *Pagerduty) FindUserByEmail(ctx context.Context, userEmail string) (User
}
if len(result.Users) > 0 && result.More {
- logger.Get(ctx).Warningf("PagerDuty returned too many results when querying by email %q", userEmail)
+ logger.Get(ctx).WarnContext(ctx, "PagerDuty returned too many results when querying user email", "email", userEmail)
}
return User{}, trace.NotFound("failed to find pagerduty user by email %s", userEmail)
@@ -387,10 +387,10 @@ func (p *Pagerduty) FilterOnCallPolicies(ctx context.Context, userID string, esc
if len(filteredIDSet) == 0 {
if anyData {
- logger.Get(ctx).WithFields(logger.Fields{
- "pd_user_id": userID,
- "pd_escalation_policy_ids": escalationPolicyIDs,
- }).Warningf("PagerDuty returned some oncalls array but none of them matched the query")
+ logger.Get(ctx).WarnContext(ctx, "PagerDuty returned some oncalls array but none of them matched the query",
+ "pd_user_id", userID,
+ "pd_escalation_policy_ids", escalationPolicyIDs,
+ )
}
return nil, nil
diff --git a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go
index aa4a8ba96eb32..58cfa27248d56 100644
--- a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go
+++ b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -65,12 +66,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := pagerduty.LoadConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -84,7 +86,7 @@ func run(configPath string, debug bool) error {
return err
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
app, err := pagerduty.NewApp(*conf)
@@ -94,8 +96,9 @@ func run(configPath string, debug bool) error {
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access PagerDuty Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access PagerDuty Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/pagerduty/testlib/fake_pagerduty.go b/integrations/access/pagerduty/testlib/fake_pagerduty.go
index 18a2a6ae24361..eee358f022458 100644
--- a/integrations/access/pagerduty/testlib/fake_pagerduty.go
+++ b/integrations/access/pagerduty/testlib/fake_pagerduty.go
@@ -32,7 +32,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/pagerduty"
"github.com/gravitational/teleport/integrations/lib/stringset"
@@ -565,6 +564,6 @@ func (s *FakePagerduty) CheckNewIncidentNote(ctx context.Context) (FakeIncidentN
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/servicenow/app.go b/integrations/access/servicenow/app.go
index 3d56f4fc97a8b..07248b488d872 100644
--- a/integrations/access/servicenow/app.go
+++ b/integrations/access/servicenow/app.go
@@ -21,6 +21,7 @@ package servicenow
import (
"context"
"fmt"
+ "log/slog"
"net/url"
"slices"
"strings"
@@ -41,6 +42,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/watcherjob"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -116,7 +118,7 @@ func (a *App) WaitReady(ctx context.Context) (bool, error) {
func (a *App) run(ctx context.Context) error {
log := logger.Get(ctx)
- log.Infof("Starting Teleport Access Servicenow Plugin")
+ log.InfoContext(ctx, "Starting Teleport Access Servicenow Plugin")
if err := a.init(ctx); err != nil {
return trace.Wrap(err)
@@ -153,9 +155,9 @@ func (a *App) run(ctx context.Context) error {
}
a.mainJob.SetReady(ok)
if ok {
- log.Info("ServiceNow plugin is ready")
+ log.InfoContext(ctx, "ServiceNow plugin is ready")
} else {
- log.Error("ServiceNow plugin is not ready")
+ log.ErrorContext(ctx, "ServiceNow plugin is not ready")
}
<-watcherJob.Done()
@@ -190,25 +192,25 @@ func (a *App) init(ctx context.Context) error {
return trace.Wrap(err)
}
- log.Debug("Starting API health check...")
+ log.DebugContext(ctx, "Starting API health check")
if err = a.serviceNow.CheckHealth(ctx); err != nil {
return trace.Wrap(err, "API health check failed")
}
- log.Debug("API health check finished ok")
+ log.DebugContext(ctx, "API health check finished ok")
return nil
}
func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) {
log := logger.Get(ctx)
- log.Debug("Checking Teleport server version")
+ log.DebugContext(ctx, "Checking Teleport server version")
pong, err := a.teleport.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion)
}
- log.Error("Unable to get Teleport server version")
+ log.ErrorContext(ctx, "Unable to get Teleport server version")
return pong, trace.Wrap(err)
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
@@ -233,16 +235,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
}
op := event.Type
reqID := event.Resource.GetName()
- ctx, _ = logger.WithField(ctx, "request_id", reqID)
+ ctx, _ = logger.With(ctx, "request_id", reqID)
switch op {
case types.OpPut:
- ctx, _ = logger.WithField(ctx, "request_op", "put")
+ ctx, _ = logger.With(ctx, "request_op", "put")
req, ok := event.Resource.(types.AccessRequest)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
- ctx, log := logger.WithField(ctx, "request_state", req.GetState().String())
+ ctx, log := logger.With(ctx, "request_state", req.GetState().String())
var err error
switch {
@@ -251,21 +253,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
case req.GetState().IsResolved():
err = a.onResolvedRequest(ctx, req)
default:
- log.WithField("event", event).Warnf("Unknown request state: %q", req.GetState())
+ log.WarnContext(ctx, "Unknown request state",
+ slog.Group("event",
+ slog.Any("type", logutils.StringerAttr(event.Type)),
+ slog.Group("resource",
+ "kind", event.Resource.GetKind(),
+ "name", event.Resource.GetName(),
+ ),
+ ),
+ )
return nil
}
if err != nil {
- log.WithError(err).Error("Failed to process request")
+ log.ErrorContext(ctx, "Failed to process request", "error", err)
return trace.Wrap(err)
}
return nil
case types.OpDelete:
- ctx, log := logger.WithField(ctx, "request_op", "delete")
+ ctx, log := logger.With(ctx, "request_op", "delete")
if err := a.onDeletedRequest(ctx, reqID); err != nil {
- log.WithError(err).Error("Failed to process deleted request")
+ log.ErrorContext(ctx, "Failed to process deleted request", "error", err)
return trace.Wrap(err)
}
return nil
@@ -276,7 +286,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error
func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error {
reqID := req.GetName()
- log := logger.Get(ctx).WithField("reqId", reqID)
+ log := logger.Get(ctx).With("req_id", reqID)
resourceNames, err := a.getResourceNames(ctx, req)
if err != nil {
@@ -303,7 +313,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err
}
if isNew {
- log.Infof("Creating servicenow incident")
+ log.InfoContext(ctx, "Creating servicenow incident")
recipientAssignee := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req)
assignees := []string{}
recipientAssignee.ForEach(func(r common.Recipient) {
@@ -375,8 +385,8 @@ func (a *App) createIncident(ctx context.Context, reqID string, reqData RequestD
if err != nil {
return trace.Wrap(err)
}
- ctx, log := logger.WithField(ctx, "servicenow_incident_id", data.IncidentID)
- log.Info("Successfully created Servicenow incident")
+ ctx, log := logger.With(ctx, "servicenow_incident_id", data.IncidentID)
+ log.InfoContext(ctx, "Successfully created Servicenow incident")
// Save servicenow incident info in plugin data.
_, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) {
@@ -420,10 +430,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to post the note: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing")
return nil
}
- ctx, _ = logger.WithField(ctx, "servicenow_incident_id", data.IncidentID)
+ ctx, _ = logger.With(ctx, "servicenow_incident_id", data.IncidentID)
slice := reqReviews[oldCount:]
if len(slice) == 0 {
@@ -445,22 +455,28 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
serviceNames, err := a.getOnCallServiceNames(req)
if err != nil {
- logger.Get(ctx).Debugf("Skipping the approval: %s", err)
+ logger.Get(ctx).DebugContext(ctx, "Skipping the approval", "error", err)
return nil
}
- log.Debugf("Checking the following shifts to see if the requester is on-call: %s", serviceNames)
+ log.DebugContext(ctx, "Checking the shifts to see if the requester is on-call", "shifts", serviceNames)
onCallUsers, err := a.getOnCallUsers(ctx, serviceNames)
if err != nil {
return trace.Wrap(err)
}
- log.Debugf("Users on-call are: %s", onCallUsers)
+ log.DebugContext(ctx, "Users on-call are", "on_call_users", onCallUsers)
if userIsOnCall := slices.Contains(onCallUsers, req.GetUser()); !userIsOnCall {
- log.Debugf("User %q is not on-call, not approving the request %q.", req.GetUser(), req.GetName())
+ log.DebugContext(ctx, "User is not on-call, not approving the request",
+ "user", req.GetUser(),
+ "request", req.GetName(),
+ )
return nil
}
- log.Debugf("User %q is on-call. Auto-approving the request %q.", req.GetUser(), req.GetName())
+ log.DebugContext(ctx, "User is on-call, auto-approving the request",
+ "user", req.GetUser(),
+ "request", req.GetName(),
+ )
if _, err := a.teleport.SubmitAccessReview(ctx, types.AccessReviewSubmission{
RequestID: req.GetName(),
Review: types.AccessReview{
@@ -474,12 +490,12 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er
},
}); err != nil {
if strings.HasSuffix(err.Error(), "has already reviewed this request") {
- log.Debug("Already reviewed the request")
+ log.DebugContext(ctx, "Already reviewed the request")
return nil
}
return trace.Wrap(err, "submitting access request")
}
- log.Info("Successfully submitted a request approval")
+ log.InfoContext(ctx, "Successfully submitted a request approval")
return nil
}
@@ -490,7 +506,7 @@ func (a *App) getOnCallUsers(ctx context.Context, serviceNames []string) ([]stri
respondersResult, err := a.serviceNow.GetOnCall(ctx, scheduleName)
if err != nil {
if trace.IsNotFound(err) {
- log.WithError(err).Error("Failed to retrieve responder from schedule")
+ log.ErrorContext(ctx, "Failed to retrieve responder from schedule", "error", err)
continue
}
return nil, trace.Wrap(err)
@@ -528,15 +544,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso
return trace.Wrap(err)
}
if !ok {
- logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing")
+ logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing")
return nil
}
- ctx, log := logger.WithField(ctx, "servicenow_incident_id", incidentID)
+ ctx, log := logger.With(ctx, "servicenow_incident_id", incidentID)
if err := a.serviceNow.ResolveIncident(ctx, incidentID, resolution); err != nil {
return trace.Wrap(err)
}
- log.Info("Successfully resolved the incident")
+ log.InfoContext(ctx, "Successfully resolved the incident")
return nil
}
diff --git a/integrations/access/servicenow/client.go b/integrations/access/servicenow/client.go
index 8d0fb4f62b9de..8c306c1efa4ee 100644
--- a/integrations/access/servicenow/client.go
+++ b/integrations/access/servicenow/client.go
@@ -287,7 +287,10 @@ func (snc *Client) CheckHealth(ctx context.Context) error {
}
if err := snc.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil {
log := logger.Get(resp.Request.Context())
- log.WithError(err).WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err)
+ log.ErrorContext(ctx, "Error while emitting servicenow plugin status",
+ "error", err,
+ "code", resp.StatusCode(),
+ )
}
}
diff --git a/integrations/access/servicenow/testlib/fake_servicenow.go b/integrations/access/servicenow/testlib/fake_servicenow.go
index 3b2d70e82a9b2..edf3fdced5fe7 100644
--- a/integrations/access/servicenow/testlib/fake_servicenow.go
+++ b/integrations/access/servicenow/testlib/fake_servicenow.go
@@ -32,7 +32,6 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/servicenow"
"github.com/gravitational/teleport/integrations/lib/stringset"
@@ -284,6 +283,6 @@ func (s *FakeServiceNow) getOnCall(rotationName string) []string {
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/access/slack/bot.go b/integrations/access/slack/bot.go
index 9c58093cb9897..e7fefa0107163 100644
--- a/integrations/access/slack/bot.go
+++ b/integrations/access/slack/bot.go
@@ -29,7 +29,6 @@ import (
"github.com/go-resty/resty/v2"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
@@ -37,6 +36,7 @@ import (
"github.com/gravitational/teleport/integrations/access/accessrequest"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/lib"
+ "github.com/gravitational/teleport/integrations/lib/logger"
pd "github.com/gravitational/teleport/integrations/lib/plugindata"
)
@@ -68,7 +68,7 @@ func onAfterResponseSlack(sink common.StatusSink) func(_ *resty.Client, resp *re
ctx, cancel := context.WithTimeout(context.Background(), statusEmitTimeout)
defer cancel()
if err := sink.Emit(ctx, status); err != nil {
- log.Errorf("Error while emitting plugin status: %v", err)
+ logger.Get(ctx).ErrorContext(ctx, "Error while emitting plugin status", "error", err)
}
}()
@@ -139,7 +139,7 @@ func (b Bot) BroadcastAccessRequestMessage(ctx context.Context, recipients []com
// the case with most SSO setups.
userRecipient, err := b.FetchRecipient(ctx, reqData.User)
if err != nil {
- log.Warningf("Unable to find user %s in Slack, will not be able to notify.", reqData.User)
+ logger.Get(ctx).WarnContext(ctx, "Unable to find user in Slack, will not be able to notify", "user", reqData.User)
}
// Include the user in the list of recipients if it exists.
diff --git a/integrations/access/slack/cmd/teleport-slack/main.go b/integrations/access/slack/cmd/teleport-slack/main.go
index 1f77db5f21492..ffa73144f540b 100644
--- a/integrations/access/slack/cmd/teleport-slack/main.go
+++ b/integrations/access/slack/cmd/teleport-slack/main.go
@@ -20,6 +20,7 @@ import (
"context"
_ "embed"
"fmt"
+ "log/slog"
"os"
"github.com/alecthomas/kingpin/v2"
@@ -65,12 +66,13 @@ func main() {
if err := run(*path, *debug); err != nil {
lib.Bail(err)
} else {
- logger.Standard().Info("Successfully shut down")
+ slog.InfoContext(context.Background(), "Successfully shut down")
}
}
}
func run(configPath string, debug bool) error {
+ ctx := context.Background()
conf, err := slack.LoadSlackConfig(configPath)
if err != nil {
return trace.Wrap(err)
@@ -84,14 +86,15 @@ func run(configPath string, debug bool) error {
return trace.Wrap(err)
}
if debug {
- logger.Standard().Debugf("DEBUG logging enabled")
+ slog.DebugContext(ctx, "DEBUG logging enabled")
}
app := slack.NewSlackApp(conf)
go lib.ServeSignals(app, common.PluginShutdownTimeout)
- logger.Standard().Infof("Starting Teleport Access Slack Plugin %s:%s", teleport.Version, teleport.Gitref)
- return trace.Wrap(
- app.Run(context.Background()),
+ slog.InfoContext(ctx, "Starting Teleport Access Slack Plugin",
+ "version", teleport.Version,
+ "git_ref", teleport.Gitref,
)
+ return trace.Wrap(app.Run(ctx))
}
diff --git a/integrations/access/slack/testlib/fake_slack.go b/integrations/access/slack/testlib/fake_slack.go
index eef81460da7f1..d18a43230c744 100644
--- a/integrations/access/slack/testlib/fake_slack.go
+++ b/integrations/access/slack/testlib/fake_slack.go
@@ -31,7 +31,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/integrations/access/slack"
)
@@ -315,6 +314,6 @@ func (s *FakeSlack) CheckMessageUpdateByResponding(ctx context.Context) (slack.M
func panicIf(err error) {
if err != nil {
- log.Panicf("%v at %v", err, string(debug.Stack()))
+ panic(fmt.Sprintf("%v at %v", err, string(debug.Stack())))
}
}
diff --git a/integrations/event-handler/fake_fluentd_test.go b/integrations/event-handler/fake_fluentd_test.go
index ecf286569f12d..72a363468ba15 100644
--- a/integrations/event-handler/fake_fluentd_test.go
+++ b/integrations/event-handler/fake_fluentd_test.go
@@ -31,8 +31,6 @@ import (
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
-
- "github.com/gravitational/teleport/integrations/lib/logger"
)
type FakeFluentd struct {
@@ -150,7 +148,6 @@ func (f *FakeFluentd) GetURL() string {
func (f *FakeFluentd) Respond(w http.ResponseWriter, r *http.Request) {
req, err := io.ReadAll(r.Body)
if err != nil {
- logger.Standard().WithError(err).Error("FakeFluentd Respond() failed to read body")
fmt.Fprintln(w, "NOK")
return
}
diff --git a/integrations/event-handler/main.go b/integrations/event-handler/main.go
index 859f6544c1e06..693b5bb24e036 100644
--- a/integrations/event-handler/main.go
+++ b/integrations/event-handler/main.go
@@ -46,8 +46,6 @@ const (
)
func main() {
- // This initializes the legacy logrus logger. This has been kept in place
- // in case any of the dependencies are still using logrus.
logger.Init()
ctx := kong.Parse(
@@ -64,17 +62,13 @@ func main() {
Format: "text",
}
if cli.Debug {
- enableLogDebug()
logCfg.Severity = "debug"
}
- log, err := logCfg.NewSLogLogger()
- if err != nil {
- fmt.Println(trace.DebugReport(trace.Wrap(err, "initializing logger")))
+
+ if err := logger.Setup(logCfg); err != nil {
+ fmt.Println(trace.DebugReport(err))
os.Exit(-1)
}
- // Whilst this package mostly dependency injects slog, upstream dependencies
- // may still use the default slog logger.
- slog.SetDefault(log)
switch {
case ctx.Command() == "version":
@@ -86,25 +80,16 @@ func main() {
os.Exit(-1)
}
case ctx.Command() == "start":
- err := start(log)
+ err := start(slog.Default())
if err != nil {
lib.Bail(err)
} else {
- log.InfoContext(context.TODO(), "Successfully shut down")
+ slog.InfoContext(context.TODO(), "Successfully shut down")
}
}
}
-// turn on log debugging
-func enableLogDebug() {
- err := logger.Setup(logger.Config{Severity: "debug", Output: "stderr"})
- if err != nil {
- fmt.Println(trace.DebugReport(err))
- os.Exit(-1)
- }
-}
-
// start spawns the main process
func start(log *slog.Logger) error {
app, err := NewApp(&cli.Start, log)
diff --git a/integrations/lib/bail.go b/integrations/lib/bail.go
index 72804cd0ac3c4..d1351bb05f7fe 100644
--- a/integrations/lib/bail.go
+++ b/integrations/lib/bail.go
@@ -19,22 +19,24 @@
package lib
import (
+ "context"
"errors"
+ "log/slog"
"os"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
)
// Bail exits with nonzero exit code and prints an error to a log.
func Bail(err error) {
+ ctx := context.Background()
var agg trace.Aggregate
if errors.As(trace.Unwrap(err), &agg) {
for i, err := range agg.Errors() {
- log.WithError(err).Errorf("Terminating with fatal error [%d]...", i+1)
+ slog.ErrorContext(ctx, "Terminating with fatal error", "error_number", i+1, "error", err)
}
} else {
- log.WithError(err).Error("Terminating with fatal error...")
+ slog.ErrorContext(ctx, "Terminating with fatal error", "error", err)
}
os.Exit(1)
}
diff --git a/integrations/lib/config.go b/integrations/lib/config.go
index 24f6c981e6686..66285167e5e36 100644
--- a/integrations/lib/config.go
+++ b/integrations/lib/config.go
@@ -22,12 +22,12 @@ import (
"context"
"errors"
"io"
+ "log/slog"
"os"
"strings"
"time"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
grpcbackoff "google.golang.org/grpc/backoff"
@@ -137,7 +137,7 @@ func NewIdentityFileWatcher(ctx context.Context, path string, interval time.Dura
}
if err := dynamicCred.Reload(); err != nil {
- log.WithError(err).Error("Failed to reload identity file from disk.")
+ slog.ErrorContext(ctx, "Failed to reload identity file from disk", "error", err)
}
timer.Reset(interval)
}
@@ -152,7 +152,7 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error)
case cfg.Addr != "":
addr = cfg.Addr
case cfg.AuthServer != "":
- log.Warn("Configuration setting `auth_server` is deprecated, consider to change it to `addr`")
+ slog.WarnContext(ctx, "Configuration setting `auth_server` is deprecated, consider to change it to `addr`")
addr = cfg.AuthServer
}
@@ -173,13 +173,13 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error)
}
if validCred, err := credentials.CheckIfExpired(creds); err != nil {
- log.Warn(err)
+ slog.WarnContext(ctx, "found expired credentials", "error", err)
if !validCred {
return nil, trace.BadParameter(
"No valid credentials found, this likely means credentials are expired. In this case, please sign new credentials and increase their TTL if needed.",
)
}
- log.Info("At least one non-expired credential has been found, continuing startup")
+ slog.InfoContext(ctx, "At least one non-expired credential has been found, continuing startup")
}
bk := grpcbackoff.DefaultConfig
diff --git a/integrations/lib/embeddedtbot/bot.go b/integrations/lib/embeddedtbot/bot.go
index e693b40793fe5..b8ed026386114 100644
--- a/integrations/lib/embeddedtbot/bot.go
+++ b/integrations/lib/embeddedtbot/bot.go
@@ -26,7 +26,6 @@ import (
"time"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
@@ -106,9 +105,9 @@ func (b *EmbeddedBot) start(ctx context.Context) {
go func() {
err := bot.Run(botCtx)
if err != nil {
- log.Errorf("bot exited with error: %s", err)
+ slog.ErrorContext(botCtx, "bot exited with error", "error", err)
} else {
- log.Infof("bot exited without error")
+ slog.InfoContext(botCtx, "bot exited without error")
}
b.errCh <- trace.Wrap(err)
}()
@@ -142,10 +141,10 @@ func (b *EmbeddedBot) waitForCredentials(ctx context.Context, deadline time.Dura
select {
case <-waitCtx.Done():
- log.Warn("context canceled while waiting for the bot client")
+ slog.WarnContext(ctx, "context canceled while waiting for the bot client")
return nil, trace.Wrap(ctx.Err())
case <-b.credential.Ready():
- log.Infof("credential ready")
+ slog.InfoContext(ctx, "credential ready")
}
return b.credential, nil
@@ -177,7 +176,7 @@ func (b *EmbeddedBot) StartAndWaitForCredentials(ctx context.Context, deadline t
// buildClient reads tbot's memory disttination, retrieves the certificates
// and builds a new Teleport client using those certs.
func (b *EmbeddedBot) buildClient(ctx context.Context) (*client.Client, error) {
- log.Infof("Building a new client to connect to %s", b.cfg.AuthServer)
+ slog.InfoContext(ctx, "Building a new client to connect to cluster", "auth_server_address", b.cfg.AuthServer)
c, err := client.New(ctx, client.Config{
Addrs: []string{b.cfg.AuthServer},
Credentials: []client.Credentials{b.credential},
diff --git a/integrations/lib/http.go b/integrations/lib/http.go
index dbb279913a5bd..6f98ad957a75c 100644
--- a/integrations/lib/http.go
+++ b/integrations/lib/http.go
@@ -24,6 +24,7 @@ import (
"crypto/x509"
"errors"
"fmt"
+ "log/slog"
"net"
"net/http"
"net/url"
@@ -33,7 +34,8 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- log "github.com/sirupsen/logrus"
+
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
// TLSConfig stores TLS configuration for a http service
@@ -178,7 +180,7 @@ func NewHTTP(config HTTPConfig) (*HTTP, error) {
if verify := config.TLS.VerifyClientCertificateFunc; verify != nil {
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, chains [][]*x509.Certificate) error {
if err := verify(chains); err != nil {
- log.WithError(err).Error("HTTPS client certificate verification failed")
+ slog.ErrorContext(context.Background(), "HTTPS client certificate verification failed", "error", err)
return err
}
return nil
@@ -217,7 +219,7 @@ func BuildURLPath(args ...interface{}) string {
// ListenAndServe runs a http(s) server on a provided port.
func (h *HTTP) ListenAndServe(ctx context.Context) error {
- defer log.Debug("HTTP server terminated")
+ defer slog.DebugContext(ctx, "HTTP server terminated")
var err error
h.server.BaseContext = func(_ net.Listener) context.Context {
@@ -256,10 +258,10 @@ func (h *HTTP) ListenAndServe(ctx context.Context) error {
}
if h.Insecure {
- log.Debugf("Starting insecure HTTP server on %s", addr)
+ slog.DebugContext(ctx, "Starting insecure HTTP server", "listen_addr", logutils.StringerAttr(addr))
err = h.server.Serve(listener)
} else {
- log.Debugf("Starting secure HTTPS server on %s", addr)
+ slog.DebugContext(ctx, "Starting secure HTTPS server", "listen_addr", logutils.StringerAttr(addr))
err = h.server.ServeTLS(listener, h.CertFile, h.KeyFile)
}
if errors.Is(err, http.ErrServerClosed) {
@@ -288,7 +290,7 @@ func (h *HTTP) ServiceJob() ServiceJob {
return NewServiceJob(func(ctx context.Context) error {
MustGetProcess(ctx).OnTerminate(func(ctx context.Context) error {
if err := h.ShutdownWithTimeout(ctx, time.Second*5); err != nil {
- log.Error("HTTP server graceful shutdown failed")
+ slog.ErrorContext(ctx, "HTTP server graceful shutdown failed")
return err
}
return nil
diff --git a/integrations/lib/logger/logger.go b/integrations/lib/logger/logger.go
index 7422f03ff906c..a1ce5bf7275ed 100644
--- a/integrations/lib/logger/logger.go
+++ b/integrations/lib/logger/logger.go
@@ -20,16 +20,11 @@ package logger
import (
"context"
- "io"
- "io/fs"
"log/slog"
"os"
- "strings"
"github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
- "github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)
@@ -41,8 +36,6 @@ type Config struct {
Format string `toml:"format"`
}
-type Fields = log.Fields
-
type contextKey struct{}
var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutils.CallerField}
@@ -50,179 +43,50 @@ var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutil
// Init sets up logger for a typical daemon scenario until configuration
// file is parsed
func Init() {
- formatter := &logutils.TextFormatter{
- EnableColors: utils.IsTerminal(os.Stderr),
- ComponentPadding: 1, // We don't use components so strip the padding
- ExtraFields: extraFields,
- }
-
- log.SetOutput(os.Stderr)
- if err := formatter.CheckAndSetDefaults(); err != nil {
- log.WithError(err).Error("unable to create text log formatter")
- return
- }
-
- log.SetFormatter(formatter)
+ enableColors := utils.IsTerminal(os.Stderr)
+ logutils.Initialize(logutils.Config{
+ Severity: slog.LevelInfo.String(),
+ Format: "text",
+ ExtraFields: extraFields,
+ EnableColors: enableColors,
+ Padding: 1,
+ })
}
func Setup(conf Config) error {
+ var enableColors bool
switch conf.Output {
case "stderr", "error", "2":
- log.SetOutput(os.Stderr)
+ enableColors = utils.IsTerminal(os.Stderr)
case "", "stdout", "out", "1":
- log.SetOutput(os.Stdout)
+ enableColors = utils.IsTerminal(os.Stdout)
default:
- // assume it's a file path:
- logFile, err := os.Create(conf.Output)
- if err != nil {
- return trace.Wrap(err, "failed to create the log file")
- }
- log.SetOutput(logFile)
}
- switch strings.ToLower(conf.Severity) {
- case "info":
- log.SetLevel(log.InfoLevel)
- case "err", "error":
- log.SetLevel(log.ErrorLevel)
- case "debug":
- log.SetLevel(log.DebugLevel)
- case "warn", "warning":
- log.SetLevel(log.WarnLevel)
- case "trace":
- log.SetLevel(log.TraceLevel)
- default:
- return trace.BadParameter("unsupported logger severity: '%v'", conf.Severity)
- }
-
- return nil
+ _, _, err := logutils.Initialize(logutils.Config{
+ Output: conf.Output,
+ Severity: conf.Severity,
+ Format: conf.Format,
+ ExtraFields: extraFields,
+ EnableColors: enableColors,
+ Padding: 1,
+ })
+ return trace.Wrap(err)
}
-// NewSLogLogger builds a slog.Logger from the logger.Config.
-// TODO(tross): Defer logging initialization to logutils.Initialize and use the
-// global slog loggers once integrations has been updated to use slog.
-func (conf Config) NewSLogLogger() (*slog.Logger, error) {
- const (
- // logFileDefaultMode is the preferred permissions mode for log file.
- logFileDefaultMode fs.FileMode = 0o644
- // logFileDefaultFlag is the preferred flags set to log file.
- logFileDefaultFlag = os.O_WRONLY | os.O_CREATE | os.O_APPEND
- )
-
- var w io.Writer
- switch conf.Output {
- case "":
- w = logutils.NewSharedWriter(os.Stderr)
- case "stderr", "error", "2":
- w = logutils.NewSharedWriter(os.Stderr)
- case "stdout", "out", "1":
- w = logutils.NewSharedWriter(os.Stdout)
- case teleport.Syslog:
- w = os.Stderr
- sw, err := logutils.NewSyslogWriter()
- if err != nil {
- slog.Default().ErrorContext(context.Background(), "Failed to switch logging to syslog", "error", err)
- break
- }
-
- // If syslog output has been configured and is supported by the operating system,
- // then the shared writer is not needed because the syslog writer is already
- // protected with a mutex.
- w = sw
- default:
- // Assume this is a file path.
- sharedWriter, err := logutils.NewFileSharedWriter(conf.Output, logFileDefaultFlag, logFileDefaultMode)
- if err != nil {
- return nil, trace.Wrap(err, "failed to init the log file shared writer")
- }
- w = logutils.NewWriterFinalizer[*logutils.FileSharedWriter](sharedWriter)
- if err := sharedWriter.RunWatcherReopen(context.Background()); err != nil {
- return nil, trace.Wrap(err)
- }
- }
-
- level := new(slog.LevelVar)
- switch strings.ToLower(conf.Severity) {
- case "", "info":
- level.Set(slog.LevelInfo)
- case "err", "error":
- level.Set(slog.LevelError)
- case teleport.DebugLevel:
- level.Set(slog.LevelDebug)
- case "warn", "warning":
- level.Set(slog.LevelWarn)
- case "trace":
- level.Set(logutils.TraceLevel)
- default:
- return nil, trace.BadParameter("unsupported logger severity: %q", conf.Severity)
- }
-
- configuredFields, err := logutils.ValidateFields(extraFields)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- var slogLogger *slog.Logger
- switch strings.ToLower(conf.Format) {
- case "":
- fallthrough // not set. defaults to 'text'
- case "text":
- enableColors := utils.IsTerminal(os.Stderr)
- slogLogger = slog.New(logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{
- Level: level,
- EnableColors: enableColors,
- ConfiguredFields: configuredFields,
- }))
- slog.SetDefault(slogLogger)
- case "json":
- slogLogger = slog.New(logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{
- Level: level,
- ConfiguredFields: configuredFields,
- }))
- slog.SetDefault(slogLogger)
- default:
- return nil, trace.BadParameter("unsupported log output format : %q", conf.Format)
- }
-
- return slogLogger, nil
-}
-
-func WithLogger(ctx context.Context, logger log.FieldLogger) context.Context {
- return withLogger(ctx, logger)
-}
-
-func withLogger(ctx context.Context, logger log.FieldLogger) context.Context {
+func WithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, contextKey{}, logger)
}
-func WithField(ctx context.Context, key string, value interface{}) (context.Context, log.FieldLogger) {
- logger := Get(ctx).WithField(key, value)
- return withLogger(ctx, logger), logger
+func With(ctx context.Context, args ...any) (context.Context, *slog.Logger) {
+ logger := Get(ctx).With(args...)
+ return WithLogger(ctx, logger), logger
}
-func WithFields(ctx context.Context, logFields Fields) (context.Context, log.FieldLogger) {
- logger := Get(ctx).WithFields(logFields)
- return withLogger(ctx, logger), logger
-}
-
-func SetField(ctx context.Context, key string, value interface{}) context.Context {
- ctx, _ = WithField(ctx, key, value)
- return ctx
-}
-
-func SetFields(ctx context.Context, logFields Fields) context.Context {
- ctx, _ = WithFields(ctx, logFields)
- return ctx
-}
-
-func Get(ctx context.Context) log.FieldLogger {
- if logger, ok := ctx.Value(contextKey{}).(log.FieldLogger); ok && logger != nil {
+func Get(ctx context.Context) *slog.Logger {
+ if logger, ok := ctx.Value(contextKey{}).(*slog.Logger); ok && logger != nil {
return logger
}
- return Standard()
-}
-
-func Standard() log.FieldLogger {
- return log.StandardLogger()
+ return slog.Default()
}
diff --git a/integrations/lib/signals.go b/integrations/lib/signals.go
index 4774915a6271b..4702455dfc7ca 100644
--- a/integrations/lib/signals.go
+++ b/integrations/lib/signals.go
@@ -20,12 +20,11 @@ package lib
import (
"context"
+ "log/slog"
"os"
"os/signal"
"syscall"
"time"
-
- log "github.com/sirupsen/logrus"
)
type Terminable interface {
@@ -48,9 +47,9 @@ func ServeSignals(app Terminable, shutdownTimeout time.Duration) {
gracefulShutdown := func() {
tctx, tcancel := context.WithTimeout(ctx, shutdownTimeout)
defer tcancel()
- log.Infof("Attempting graceful shutdown...")
+ slog.InfoContext(tctx, "Attempting graceful shutdown")
if err := app.Shutdown(tctx); err != nil {
- log.Infof("Graceful shutdown failed. Trying fast shutdown...")
+ slog.InfoContext(tctx, "Graceful shutdown failed, attempting fast shutdown")
app.Close()
}
}
diff --git a/integrations/lib/tctl/tctl.go b/integrations/lib/tctl/tctl.go
index 25e7e5e95e0da..5fa0a3252b45b 100644
--- a/integrations/lib/tctl/tctl.go
+++ b/integrations/lib/tctl/tctl.go
@@ -27,6 +27,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/lib/logger"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
var regexpStatusCAPin = regexp.MustCompile(`CA pin +(sha256:[a-zA-Z0-9]+)`)
@@ -59,10 +60,14 @@ func (tctl Tctl) Sign(ctx context.Context, username, format, outPath string) err
outPath,
)
cmd := exec.CommandContext(ctx, tctl.cmd(), args...)
- log.Debugf("Running %s", cmd)
+ log.DebugContext(ctx, "Running tctl auth sign", "command", logutils.StringerAttr(cmd))
output, err := cmd.CombinedOutput()
if err != nil {
- log.WithError(err).WithField("args", args).Debug("tctl auth sign failed:", string(output))
+ log.DebugContext(ctx, "tctl auth sign failed",
+ "error", err,
+ "args", args,
+ "command_output", string(output),
+ )
return trace.Wrap(err, "tctl auth sign failed")
}
return nil
@@ -73,7 +78,7 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error {
log := logger.Get(ctx)
args := append(tctl.baseArgs(), "create")
cmd := exec.CommandContext(ctx, tctl.cmd(), args...)
- log.Debugf("Running %s", cmd)
+ log.DebugContext(ctx, "Running tctl create", "command", logutils.StringerAttr(cmd))
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return trace.Wrap(err, "failed to get stdin pipe")
@@ -81,16 +86,19 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error {
go func() {
defer func() {
if err := stdinPipe.Close(); err != nil {
- log.WithError(trace.Wrap(err)).Error("Failed to close stdin pipe")
+ log.ErrorContext(ctx, "Failed to close stdin pipe", "error", err)
}
}()
if err := writeResourcesYAML(stdinPipe, resources); err != nil {
- log.WithError(trace.Wrap(err)).Error("Failed to serialize resources stdin")
+ log.ErrorContext(ctx, "Failed to serialize resources stdin", "error", err)
}
}()
output, err := cmd.CombinedOutput()
if err != nil {
- log.WithError(err).Debug("tctl create failed:", string(output))
+ log.DebugContext(ctx, "tctl create failed",
+ "error", err,
+ "command_output", string(output),
+ )
return trace.Wrap(err, "tctl create failed")
}
return nil
@@ -102,7 +110,7 @@ func (tctl Tctl) GetAll(ctx context.Context, query string) ([]types.Resource, er
args := append(tctl.baseArgs(), "get", query)
cmd := exec.CommandContext(ctx, tctl.cmd(), args...)
- log.Debugf("Running %s", cmd)
+ log.DebugContext(ctx, "Running tctl get", "command", logutils.StringerAttr(cmd))
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, trace.Wrap(err, "failed to get stdout")
@@ -140,7 +148,7 @@ func (tctl Tctl) GetCAPin(ctx context.Context) (string, error) {
args := append(tctl.baseArgs(), "status")
cmd := exec.CommandContext(ctx, tctl.cmd(), args...)
- log.Debugf("Running %s", cmd)
+ log.DebugContext(ctx, "Running tctl status", "command", logutils.StringerAttr(cmd))
output, err := cmd.Output()
if err != nil {
return "", trace.Wrap(err, "failed to get auth status")
diff --git a/integrations/lib/testing/integration/suite.go b/integrations/lib/testing/integration/suite.go
index 22c0754f66a3b..c0f03c647ef75 100644
--- a/integrations/lib/testing/integration/suite.go
+++ b/integrations/lib/testing/integration/suite.go
@@ -93,7 +93,7 @@ func (s *Suite) initContexts(oldT *testing.T, newT *testing.T) {
} else {
baseCtx = context.Background()
}
- baseCtx, _ = logger.WithField(baseCtx, "test", newT.Name())
+ baseCtx, _ = logger.With(baseCtx, "test", newT.Name())
baseCtx, cancel := context.WithCancel(baseCtx)
newT.Cleanup(cancel)
@@ -163,7 +163,7 @@ func (s *Suite) StartApp(app AppI) {
if err := app.Run(ctx); err != nil {
// We're in a goroutine so we can't just require.NoError(t, err).
// All we can do is to log an error.
- logger.Get(ctx).WithError(err).Error("Application failed")
+ logger.Get(ctx).ErrorContext(ctx, "Application failed", "error", err)
}
}()
diff --git a/integrations/lib/watcherjob/watcherjob.go b/integrations/lib/watcherjob/watcherjob.go
index 2999b86aaad0b..a7d2d14482ae6 100644
--- a/integrations/lib/watcherjob/watcherjob.go
+++ b/integrations/lib/watcherjob/watcherjob.go
@@ -130,23 +130,23 @@ func newJobWithEvents(events types.Events, config Config, fn EventFunc, watchIni
if config.FailFast {
return trace.WrapWithMessage(err, "Connection problem detected. Exiting as fail fast is on.")
}
- log.WithError(err).Error("Connection problem detected. Attempting to reconnect.")
+ log.ErrorContext(ctx, "Connection problem detected, attempting to reconnect", "error", err)
case errors.Is(err, io.EOF):
if config.FailFast {
return trace.WrapWithMessage(err, "Watcher stream closed. Exiting as fail fast is on.")
}
- log.WithError(err).Error("Watcher stream closed. Attempting to reconnect.")
+ log.ErrorContext(ctx, "Watcher stream closed attempting to reconnect", "error", err)
case lib.IsCanceled(err):
- log.Debug("Watcher context is canceled")
+ log.DebugContext(ctx, "Watcher context is canceled")
return trace.Wrap(err)
default:
- log.WithError(err).Error("Watcher event loop failed")
+ log.ErrorContext(ctx, "Watcher event loop failed", "error", err)
return trace.Wrap(err)
}
// To mitigate a potentially aggressive retry loop, we wait
if err := bk.Do(ctx); err != nil {
- log.Debug("Watcher context was canceled while waiting before a reconnection")
+ log.DebugContext(ctx, "Watcher context was canceled while waiting before a reconnection")
return trace.Wrap(err)
}
}
@@ -162,7 +162,7 @@ func (job job) watchEvents(ctx context.Context) error {
}
defer func() {
if err := watcher.Close(); err != nil {
- logger.Get(ctx).WithError(err).Error("Failed to close a watcher")
+ logger.Get(ctx).ErrorContext(ctx, "Failed to close a watcher", "error", err)
}
}()
@@ -170,7 +170,7 @@ func (job job) watchEvents(ctx context.Context) error {
return trace.Wrap(err)
}
- logger.Get(ctx).Debug("Watcher connected")
+ logger.Get(ctx).DebugContext(ctx, "Watcher connected")
job.SetReady(true)
for {
@@ -253,7 +253,7 @@ func (job job) eventLoop(ctx context.Context) error {
event := *eventPtr
resource := event.Resource
if resource == nil {
- log.Error("received an event with empty resource field")
+ log.ErrorContext(ctx, "received an event with empty resource field")
}
key := eventKey{kind: resource.GetKind(), name: resource.GetName()}
if queue, loaded := queues[key]; loaded {
diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go
index b1c7c7339c4ba..585c82058d5fb 100644
--- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go
+++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go
@@ -21,38 +21,37 @@
package main
import (
+ "context"
+ "log/slog"
"os"
- "github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
-
crdgen "github.com/gravitational/teleport/integrations/operator/crdgen"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
func main() {
- log.SetLevel(log.DebugLevel)
- log.SetOutput(os.Stderr)
+ slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr,
+ logutils.SlogTextHandlerConfig{
+ Level: slog.LevelDebug,
+ },
+ )))
+ ctx := context.Background()
inputPath := os.Getenv(crdgen.PluginInputPathEnvironment)
if inputPath == "" {
- log.Error(
- trace.BadParameter(
- "When built with the 'debug' tag, the input path must be set through the environment variable: %s",
- crdgen.PluginInputPathEnvironment,
- ),
- )
+ slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable")
os.Exit(-1)
}
- log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath)
+ slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath)
req, err := crdgen.ReadRequestFromFile(inputPath)
if err != nil {
- log.WithError(err).Error("error reading request from file")
+ slog.ErrorContext(ctx, "error reading request from file", "error", err)
os.Exit(-1)
}
if err := crdgen.HandleDocsRequest(req); err != nil {
- log.WithError(err).Error("Failed to generate docs")
+ slog.ErrorContext(ctx, "Failed to generate docs", "error", err)
os.Exit(-1)
}
}
diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go
index e091e5a8c1d0f..ac1be771b0bf0 100644
--- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go
+++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go
@@ -21,20 +21,26 @@
package main
import (
+ "context"
+ "log/slog"
"os"
"github.com/gogo/protobuf/vanity/command"
- log "github.com/sirupsen/logrus"
crdgen "github.com/gravitational/teleport/integrations/operator/crdgen"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
func main() {
- log.SetLevel(log.DebugLevel)
- log.SetOutput(os.Stderr)
+ slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr,
+ logutils.SlogTextHandlerConfig{
+ Level: slog.LevelDebug,
+ },
+ )))
+
req := command.Read()
if err := crdgen.HandleDocsRequest(req); err != nil {
- log.WithError(err).Error("Failed to generate schema")
+ slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err)
os.Exit(-1)
}
}
diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go
index bf19cf7eaca87..2da3e47ab9ec8 100644
--- a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go
+++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go
@@ -21,38 +21,37 @@
package main
import (
+ "context"
+ "log/slog"
"os"
- "github.com/gravitational/trace"
- log "github.com/sirupsen/logrus"
-
crdgen "github.com/gravitational/teleport/integrations/operator/crdgen"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
func main() {
- log.SetLevel(log.DebugLevel)
- log.SetOutput(os.Stderr)
+ slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr,
+ logutils.SlogTextHandlerConfig{
+ Level: slog.LevelDebug,
+ },
+ )))
+ ctx := context.Background()
inputPath := os.Getenv(crdgen.PluginInputPathEnvironment)
if inputPath == "" {
- log.Error(
- trace.BadParameter(
- "When built with the 'debug' tag, the input path must be set through the environment variable: %s",
- crdgen.PluginInputPathEnvironment,
- ),
- )
+ slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable")
os.Exit(-1)
}
- log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath)
+ slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath)
req, err := crdgen.ReadRequestFromFile(inputPath)
if err != nil {
- log.WithError(err).Error("error reading request from file")
+ slog.ErrorContext(ctx, "error reading request from file", "error", err)
os.Exit(-1)
}
if err := crdgen.HandleCRDRequest(req); err != nil {
- log.WithError(err).Error("Failed to generate schema")
+ slog.ErrorContext(ctx, "Failed to generate schema", "error", err)
os.Exit(-1)
}
}
diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go
index 863af95862505..a557993626415 100644
--- a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go
+++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go
@@ -21,20 +21,26 @@
package main
import (
+ "context"
+ "log/slog"
"os"
"github.com/gogo/protobuf/vanity/command"
- log "github.com/sirupsen/logrus"
crdgen "github.com/gravitational/teleport/integrations/operator/crdgen"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
func main() {
- log.SetLevel(log.DebugLevel)
- log.SetOutput(os.Stderr)
+ slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr,
+ logutils.SlogTextHandlerConfig{
+ Level: slog.LevelDebug,
+ },
+ )))
+
req := command.Read()
if err := crdgen.HandleCRDRequest(req); err != nil {
- log.WithError(err).Error("Failed to generate schema")
+ slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err)
os.Exit(-1)
}
}
diff --git a/integrations/terraform/Makefile b/integrations/terraform/Makefile
index 572a07d4d45dc..149aef0ed5b4b 100644
--- a/integrations/terraform/Makefile
+++ b/integrations/terraform/Makefile
@@ -47,7 +47,7 @@ $(BUILDDIR)/terraform-provider-teleport_%: terraform-provider-teleport-v$(VERSIO
CUSTOM_IMPORTS_TMP_DIR ?= /tmp/protoc-gen-terraform/custom-imports
# This version must match the version installed by .github/workflows/lint.yaml
-PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.0
+PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.2
PROTOC_GEN_TERRAFORM_EXISTS := $(shell $(PROTOC_GEN_TERRAFORM) version 2>&1 >/dev/null | grep 'protoc-gen-terraform $(PROTOC_GEN_TERRAFORM_VERSION)')
.PHONY: gen-tfschema
diff --git a/integrations/terraform/README.md b/integrations/terraform/README.md
index 53e752f725d41..dde74bc7b793b 100644
--- a/integrations/terraform/README.md
+++ b/integrations/terraform/README.md
@@ -7,9 +7,9 @@ Please, refer to [official documentation](https://goteleport.com/docs/admin-guid
## Development
1. Install [`protobuf`](https://grpc.io/docs/protoc-installation/).
-2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.0.
+2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.2.
- ```go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225```
+ ```go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2```
3. Install [`Terraform`](https://learn.hashicorp.com/tutorials/terraform/install-cli) v1.1.0+. Alternatively, you can use [`tfenv`](https://github.com/tfutils/tfenv). Please note that on Mac M1 you need to specify `TFENV_ARCH` (ex: `TFENV_ARCH=arm64 tfenv install 1.1.6`).
diff --git a/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf b/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf
index e48ab1e5d0dd2..34dee932f430f 100644
--- a/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf
+++ b/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf
@@ -9,7 +9,9 @@ resource "teleport_workload_identity" "example" {
{
conditions = [{
attribute = "user.name"
- equals = "noah"
+ eq = {
+ value = "my-user"
+ }
}]
}
]
diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod
index d3240ffff8135..5222dc914a105 100644
--- a/integrations/terraform/go.mod
+++ b/integrations/terraform/go.mod
@@ -21,7 +21,6 @@ require (
github.com/hashicorp/terraform-plugin-log v0.9.0
github.com/hashicorp/terraform-plugin-sdk/v2 v2.10.1
github.com/jonboulle/clockwork v0.4.0
- github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.10.0
google.golang.org/grpc v1.69.2
google.golang.org/protobuf v1.36.2
@@ -307,6 +306,7 @@ require (
github.com/shirou/gopsutil/v4 v4.24.12 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/sijms/go-ora/v2 v2.8.22 // indirect
+ github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/spf13/cobra v1.8.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum
index 106e4e41c759b..da4bca430e263 100644
--- a/integrations/terraform/go.sum
+++ b/integrations/terraform/go.sum
@@ -798,6 +798,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58 h1:
github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58/go.mod h1:1FDesv+tfF2w5mRnLQbB8P33BPfxrngXtfNcdnrtmjw=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M=
+github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU=
+github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45 h1:ZxB8WFVYwolhDZxuZXoesHkl+L9cXLWy0K/G0QkNATc=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45/go.mod h1:1krrbyoFFDqaNldmltPTP+mK3sAXLHPoaFtISOw2Hkk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI=
diff --git a/integrations/terraform/provider/errors.go b/integrations/terraform/provider/errors.go
index d31715366d192..6c0f838b474bf 100644
--- a/integrations/terraform/provider/errors.go
+++ b/integrations/terraform/provider/errors.go
@@ -17,9 +17,11 @@ limitations under the License.
package provider
import (
+ "context"
+ "log/slog"
+
"github.com/gravitational/trace"
"github.com/hashicorp/terraform-plugin-framework/diag"
- log "github.com/sirupsen/logrus"
)
// diagFromWrappedErr wraps error with additional information
@@ -43,7 +45,7 @@ func diagFromWrappedErr(summary string, err error, kind string) diag.Diagnostic
// diagFromErr converts error to diag.Diagnostics. If logging level is debug, provides trace.DebugReport instead of short text.
func diagFromErr(summary string, err error) diag.Diagnostic {
- if log.GetLevel() >= log.DebugLevel {
+ if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
return diag.NewErrorDiagnostic(err.Error(), trace.DebugReport(err))
}
diff --git a/integrations/terraform/provider/provider.go b/integrations/terraform/provider/provider.go
index 13b20d20c434f..99d460a49f806 100644
--- a/integrations/terraform/provider/provider.go
+++ b/integrations/terraform/provider/provider.go
@@ -19,6 +19,7 @@ package provider
import (
"context"
"fmt"
+ "log/slog"
"net"
"os"
"strconv"
@@ -29,13 +30,13 @@ import (
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/tfsdk"
"github.com/hashicorp/terraform-plugin-framework/types"
- log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/grpclog"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
)
const (
@@ -305,7 +306,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq
return
}
- log.WithFields(log.Fields{"addr": addr}).Debug("Using Teleport address")
+ slog.DebugContext(ctx, "Using Teleport address", "addr", addr)
dialTimeoutDuration, err := time.ParseDuration(dialTimeoutDurationStr)
if err != nil {
@@ -393,7 +394,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq
// checkTeleportVersion ensures that Teleport version is at least minServerVersion
func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Client, resp *tfsdk.ConfigureProviderResponse) bool {
- log.Debug("Checking Teleport server version")
+ slog.DebugContext(ctx, "Checking Teleport server version")
pong, err := client.Ping(ctx)
if err != nil {
if trace.IsNotImplemented(err) {
@@ -403,13 +404,13 @@ func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Clie
)
return false
}
- log.WithError(err).Debug("Teleport version check error!")
+ slog.DebugContext(ctx, "Teleport version check error", "error", err)
resp.Diagnostics.AddError("Unable to get Teleport server version!", "Unable to get Teleport server version!")
return false
}
err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion)
if err != nil {
- log.WithError(err).Debug("Teleport version check error!")
+ slog.DebugContext(ctx, "Teleport version check error", "error", err)
resp.Diagnostics.AddError("Teleport version check error!", err.Error())
return false
}
@@ -447,7 +448,7 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon
_, _, err := net.SplitHostPort(addr)
if err != nil {
- log.WithField("addr", addr).WithError(err).Debug("Teleport address format error!")
+ slog.DebugContext(context.Background(), "Teleport address format error", "error", err, "addr", addr)
resp.Diagnostics.AddError(
"Invalid Teleport address format",
fmt.Sprintf("Teleport address must be specified as host:port. Got %q", addr),
@@ -461,20 +462,32 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon
// configureLog configures logging
func (p *Provider) configureLog() {
+ level := slog.LevelError
// Get Terraform log level
- level, err := log.ParseLevel(os.Getenv("TF_LOG"))
- if err != nil {
- log.SetLevel(log.ErrorLevel)
- } else {
- log.SetLevel(level)
+ switch strings.ToLower(os.Getenv("TF_LOG")) {
+ case "panic", "fatal", "error":
+ level = slog.LevelError
+ case "warn", "warning":
+ level = slog.LevelWarn
+ case "info":
+ level = slog.LevelInfo
+ case "debug":
+ level = slog.LevelDebug
+ case "trace":
+ level = logutils.TraceLevel
}
- log.SetFormatter(&log.TextFormatter{})
+ _, _, err := logutils.Initialize(logutils.Config{
+ Severity: level.String(),
+ Format: "text",
+ })
+ if err != nil {
+ return
+ }
// Show GRPC debug logs only if TF_LOG=DEBUG
- if log.GetLevel() >= log.DebugLevel {
- l := grpclog.NewLoggerV2(log.StandardLogger().Out, log.StandardLogger().Out, log.StandardLogger().Out)
- grpclog.SetLoggerV2(l)
+ if level <= slog.LevelDebug {
+ grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
}
}
diff --git a/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf b/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf
index b5d0ebe8aae08..a506ee5773d06 100644
--- a/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf
+++ b/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf
@@ -9,7 +9,9 @@ resource "teleport_workload_identity" "test" {
{
conditions = [{
attribute = "user.name"
- equals = "foo"
+ eq = {
+ value = "foo"
+ }
}]
}
]
diff --git a/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf b/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf
index cced0a4f8ecdd..bb64491258471 100644
--- a/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf
+++ b/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf
@@ -9,7 +9,9 @@ resource "teleport_workload_identity" "test" {
{
conditions = [{
attribute = "user.name"
- equals = "foo"
+ eq = {
+ value = "foo"
+ }
}]
}
]
diff --git a/integrations/terraform/testlib/workload_identity_test.go b/integrations/terraform/testlib/workload_identity_test.go
index 1e6d84cf6feb9..3e6d5a6ca4342 100644
--- a/integrations/terraform/testlib/workload_identity_test.go
+++ b/integrations/terraform/testlib/workload_identity_test.go
@@ -55,7 +55,7 @@ func (s *TerraformSuiteOSS) TestWorkloadIdentity() {
resource.TestCheckResourceAttr(name, "kind", "workload_identity"),
resource.TestCheckResourceAttr(name, "spec.spiffe.id", "/test"),
resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.attribute", "user.name"),
- resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.equals", "foo"),
+ resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.eq.value", "foo"),
),
},
{
@@ -68,7 +68,7 @@ func (s *TerraformSuiteOSS) TestWorkloadIdentity() {
resource.TestCheckResourceAttr(name, "kind", "workload_identity"),
resource.TestCheckResourceAttr(name, "spec.spiffe.id", "/test/updated"),
resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.attribute", "user.name"),
- resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.equals", "foo"),
+ resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.eq.value", "foo"),
),
},
{
@@ -101,7 +101,11 @@ func (s *TerraformSuiteOSS) TestImportWorkloadIdentity() {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "user.name",
- Equals: "foo",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "foo",
+ },
+ },
},
},
},
@@ -133,7 +137,7 @@ func (s *TerraformSuiteOSS) TestImportWorkloadIdentity() {
require.Equal(t, types.KindWorkloadIdentity, state[0].Attributes["kind"])
require.Equal(t, "/test", state[0].Attributes["spec.spiffe.id"])
require.Equal(t, "user.name", state[0].Attributes["spec.rules.allow.0.conditions.0.attribute"])
- require.Equal(t, "foo", state[0].Attributes["spec.rules.allow.0.conditions.0.equals"])
+ require.Equal(t, "foo", state[0].Attributes["spec.rules.allow.0.conditions.0.eq.value"])
return nil
},
diff --git a/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go b/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go
index 5a76525cde345..27c4412b764de 100644
--- a/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go
+++ b/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go
@@ -107,10 +107,41 @@ func GenSchemaWorkloadIdentity(ctx context.Context) (github_com_hashicorp_terraf
Optional: true,
Type: github_com_hashicorp_terraform_plugin_framework_types.StringType,
},
- "equals": {
- Description: "An exact string that the attribute must match.",
+ "eq": {
+ Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"value": {
+ Description: "The value to compare the attribute against.",
+ Optional: true,
+ Type: github_com_hashicorp_terraform_plugin_framework_types.StringType,
+ }}),
+ Description: "The attribute casted to a string must be equal to the value.",
+ Optional: true,
+ },
+ "in": {
+ Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"values": {
+ Description: "The list of values to compare the attribute against.",
+ Optional: true,
+ Type: github_com_hashicorp_terraform_plugin_framework_types.ListType{ElemType: github_com_hashicorp_terraform_plugin_framework_types.StringType},
+ }}),
+ Description: "The attribute casted to a string must be in the list of values.",
+ Optional: true,
+ },
+ "not_eq": {
+ Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"value": {
+ Description: "The value to compare the attribute against.",
+ Optional: true,
+ Type: github_com_hashicorp_terraform_plugin_framework_types.StringType,
+ }}),
+ Description: "The attribute casted to a string must not be equal to the value.",
+ Optional: true,
+ },
+ "not_in": {
+ Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"values": {
+ Description: "The list of values to compare the attribute against.",
+ Optional: true,
+ Type: github_com_hashicorp_terraform_plugin_framework_types.ListType{ElemType: github_com_hashicorp_terraform_plugin_framework_types.StringType},
+ }}),
+ Description: "The attribute casted to a string must not be in the list of values.",
Optional: true,
- Type: github_com_hashicorp_terraform_plugin_framework_types.StringType,
},
}),
Description: "The conditions that must be met for this rule to be considered passed.",
@@ -408,6 +439,7 @@ func CopyWorkloadIdentityFromTerraform(_ context.Context, tf github_com_hashicor
tf := v
t = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition{}
obj := t
+ obj.Operator = nil
{
a, ok := tf.Attrs["attribute"]
if !ok {
@@ -426,19 +458,162 @@ func CopyWorkloadIdentityFromTerraform(_ context.Context, tf github_com_hashicor
}
}
{
- a, ok := tf.Attrs["equals"]
+ a, ok := tf.Attrs["eq"]
if !ok {
- diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals"})
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq"})
} else {
- v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object)
if !ok {
- diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq", "github.com/hashicorp/terraform-plugin-framework/types.Object"})
} else {
- var t string
if !v.Null && !v.Unknown {
- t = string(v.Value)
+ b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionEq{}
+ obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq{Eq: b}
+ obj := b
+ tf := v
+ {
+ a, ok := tf.Attrs["value"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ } else {
+ var t string
+ if !v.Null && !v.Unknown {
+ t = string(v.Value)
+ }
+ obj.Value = t
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ {
+ a, ok := tf.Attrs["not_eq"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq", "github.com/hashicorp/terraform-plugin-framework/types.Object"})
+ } else {
+ if !v.Null && !v.Unknown {
+ b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionNotEq{}
+ obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq{NotEq: b}
+ obj := b
+ tf := v
+ {
+ a, ok := tf.Attrs["value"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ } else {
+ var t string
+ if !v.Null && !v.Unknown {
+ t = string(v.Value)
+ }
+ obj.Value = t
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ {
+ a, ok := tf.Attrs["in"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in", "github.com/hashicorp/terraform-plugin-framework/types.Object"})
+ } else {
+ if !v.Null && !v.Unknown {
+ b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionIn{}
+ obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In{In: b}
+ obj := b
+ tf := v
+ {
+ a, ok := tf.Attrs["values"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.List)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.List"})
+ } else {
+ obj.Values = make([]string, len(v.Elems))
+ if !v.Null && !v.Unknown {
+ for k, a := range v.Elems {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github_com_hashicorp_terraform_plugin_framework_types.String"})
+ } else {
+ var t string
+ if !v.Null && !v.Unknown {
+ t = string(v.Value)
+ }
+ obj.Values[k] = t
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ {
+ a, ok := tf.Attrs["not_in"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in", "github.com/hashicorp/terraform-plugin-framework/types.Object"})
+ } else {
+ if !v.Null && !v.Unknown {
+ b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionNotIn{}
+ obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn{NotIn: b}
+ obj := b
+ tf := v
+ {
+ a, ok := tf.Attrs["values"]
+ if !ok {
+ diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values"})
+ } else {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.List)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.List"})
+ } else {
+ obj.Values = make([]string, len(v.Elems))
+ if !v.Null && !v.Unknown {
+ for k, a := range v.Elems {
+ v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github_com_hashicorp_terraform_plugin_framework_types.String"})
+ } else {
+ var t string
+ if !v.Null && !v.Unknown {
+ t = string(v.Value)
+ }
+ obj.Values[k] = t
+ }
+ }
+ }
+ }
+ }
+ }
}
- obj.Equals = t
}
}
}
@@ -984,25 +1159,297 @@ func CopyWorkloadIdentityToTerraform(ctx context.Context, obj *github_com_gravit
}
}
{
- t, ok := tf.AttrTypes["equals"]
+ a, ok := tf.AttrTypes["eq"]
if !ok {
- diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals"})
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq"})
} else {
- v, ok := tf.Attrs["equals"].(github_com_hashicorp_terraform_plugin_framework_types.String)
+ obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq)
if !ok {
- i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil))
- if err != nil {
- diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.equals", err})
+ obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq{}
+ }
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"})
+ } else {
+ v, ok := tf.Attrs["eq"].(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ v = github_com_hashicorp_terraform_plugin_framework_types.Object{
+
+ AttrTypes: o.AttrTypes,
+ Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)),
+ }
+ } else {
+ if v.Attrs == nil {
+ v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes))
+ }
}
- v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if obj.Eq == nil {
+ v.Null = true
+ } else {
+ obj := obj.Eq
+ tf := &v
+ {
+ t, ok := tf.AttrTypes["value"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value"})
+ } else {
+ v, ok := tf.Attrs["value"].(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil))
+ if err != nil {
+ diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", err})
+ }
+ v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ }
+ v.Null = string(obj.Value) == ""
+ }
+ v.Value = string(obj.Value)
+ v.Unknown = false
+ tf.Attrs["value"] = v
+ }
+ }
+ }
+ v.Unknown = false
+ tf.Attrs["eq"] = v
+ }
+ }
+ }
+ {
+ a, ok := tf.AttrTypes["not_eq"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq"})
+ } else {
+ obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq)
+ if !ok {
+ obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq{}
+ }
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"})
+ } else {
+ v, ok := tf.Attrs["not_eq"].(github_com_hashicorp_terraform_plugin_framework_types.Object)
if !ok {
- diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ v = github_com_hashicorp_terraform_plugin_framework_types.Object{
+
+ AttrTypes: o.AttrTypes,
+ Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)),
+ }
+ } else {
+ if v.Attrs == nil {
+ v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes))
+ }
+ }
+ if obj.NotEq == nil {
+ v.Null = true
+ } else {
+ obj := obj.NotEq
+ tf := &v
+ {
+ t, ok := tf.AttrTypes["value"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value"})
+ } else {
+ v, ok := tf.Attrs["value"].(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil))
+ if err != nil {
+ diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", err})
+ }
+ v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ }
+ v.Null = string(obj.Value) == ""
+ }
+ v.Value = string(obj.Value)
+ v.Unknown = false
+ tf.Attrs["value"] = v
+ }
+ }
}
- v.Null = string(obj.Equals) == ""
+ v.Unknown = false
+ tf.Attrs["not_eq"] = v
+ }
+ }
+ }
+ {
+ a, ok := tf.AttrTypes["in"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in"})
+ } else {
+ obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In)
+ if !ok {
+ obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In{}
+ }
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"})
+ } else {
+ v, ok := tf.Attrs["in"].(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ v = github_com_hashicorp_terraform_plugin_framework_types.Object{
+
+ AttrTypes: o.AttrTypes,
+ Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)),
+ }
+ } else {
+ if v.Attrs == nil {
+ v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes))
+ }
+ }
+ if obj.In == nil {
+ v.Null = true
+ } else {
+ obj := obj.In
+ tf := &v
+ {
+ a, ok := tf.AttrTypes["values"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values"})
+ } else {
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ListType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.ListType"})
+ } else {
+ c, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.List)
+ if !ok {
+ c = github_com_hashicorp_terraform_plugin_framework_types.List{
+
+ ElemType: o.ElemType,
+ Elems: make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)),
+ Null: true,
+ }
+ } else {
+ if c.Elems == nil {
+ c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values))
+ }
+ }
+ if obj.Values != nil {
+ t := o.ElemType
+ if len(obj.Values) != len(c.Elems) {
+ c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values))
+ }
+ for k, a := range obj.Values {
+ v, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil))
+ if err != nil {
+ diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.in.values", err})
+ }
+ v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ }
+ v.Null = string(a) == ""
+ }
+ v.Value = string(a)
+ v.Unknown = false
+ c.Elems[k] = v
+ }
+ if len(obj.Values) > 0 {
+ c.Null = false
+ }
+ }
+ c.Unknown = false
+ tf.Attrs["values"] = c
+ }
+ }
+ }
+ }
+ v.Unknown = false
+ tf.Attrs["in"] = v
+ }
+ }
+ }
+ {
+ a, ok := tf.AttrTypes["not_in"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in"})
+ } else {
+ obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn)
+ if !ok {
+ obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn{}
+ }
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"})
+ } else {
+ v, ok := tf.Attrs["not_in"].(github_com_hashicorp_terraform_plugin_framework_types.Object)
+ if !ok {
+ v = github_com_hashicorp_terraform_plugin_framework_types.Object{
+
+ AttrTypes: o.AttrTypes,
+ Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)),
+ }
+ } else {
+ if v.Attrs == nil {
+ v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes))
+ }
+ }
+ if obj.NotIn == nil {
+ v.Null = true
+ } else {
+ obj := obj.NotIn
+ tf := &v
+ {
+ a, ok := tf.AttrTypes["values"]
+ if !ok {
+ diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values"})
+ } else {
+ o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ListType)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.ListType"})
+ } else {
+ c, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.List)
+ if !ok {
+ c = github_com_hashicorp_terraform_plugin_framework_types.List{
+
+ ElemType: o.ElemType,
+ Elems: make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)),
+ Null: true,
+ }
+ } else {
+ if c.Elems == nil {
+ c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values))
+ }
+ }
+ if obj.Values != nil {
+ t := o.ElemType
+ if len(obj.Values) != len(c.Elems) {
+ c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values))
+ }
+ for k, a := range obj.Values {
+ v, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil))
+ if err != nil {
+ diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", err})
+ }
+ v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String)
+ if !ok {
+ diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.String"})
+ }
+ v.Null = string(a) == ""
+ }
+ v.Value = string(a)
+ v.Unknown = false
+ c.Elems[k] = v
+ }
+ if len(obj.Values) > 0 {
+ c.Null = false
+ }
+ }
+ c.Unknown = false
+ tf.Attrs["values"] = c
+ }
+ }
+ }
+ }
+ v.Unknown = false
+ tf.Attrs["not_in"] = v
}
- v.Value = string(obj.Equals)
- v.Unknown = false
- tf.Attrs["equals"] = v
}
}
}
diff --git a/lib/auth/machineid/workloadidentityv1/decision.go b/lib/auth/machineid/workloadidentityv1/decision.go
index 4e959efe6ee12..ccbbde2967c90 100644
--- a/lib/auth/machineid/workloadidentityv1/decision.go
+++ b/lib/auth/machineid/workloadidentityv1/decision.go
@@ -176,8 +176,25 @@ ruleLoop:
if err != nil {
return trace.Wrap(err)
}
- if val != condition.Equals {
- continue ruleLoop
+ switch c := condition.Operator.(type) {
+ case *workloadidentityv1pb.WorkloadIdentityCondition_Eq:
+ if val != c.Eq.Value {
+ continue ruleLoop
+ }
+ case *workloadidentityv1pb.WorkloadIdentityCondition_NotEq:
+ if val == c.NotEq.Value {
+ continue ruleLoop
+ }
+ case *workloadidentityv1pb.WorkloadIdentityCondition_In:
+ if !slices.Contains(c.In.Values, val) {
+ continue ruleLoop
+ }
+ case *workloadidentityv1pb.WorkloadIdentityCondition_NotIn:
+ if slices.Contains(c.NotIn.Values, val) {
+ continue ruleLoop
+ }
+ default:
+ return trace.BadParameter("unsupported operator %T", c)
}
}
return nil
diff --git a/lib/auth/machineid/workloadidentityv1/decision_test.go b/lib/auth/machineid/workloadidentityv1/decision_test.go
index 5d00bf7595669..3d2b9ed4cff95 100644
--- a/lib/auth/machineid/workloadidentityv1/decision_test.go
+++ b/lib/auth/machineid/workloadidentityv1/decision_test.go
@@ -263,28 +263,285 @@ func Test_evaluateRules(t *testing.T) {
User: &workloadidentityv1pb.UserAttrs{
Name: "foo",
},
+ Workload: &workloadidentityv1pb.WorkloadAttrs{
+ Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
+ PodName: "pod1",
+ Namespace: "default",
+ },
+ },
+ }
+
+ var noMatchRule require.ErrorAssertionFunc = func(t require.TestingT, err error, i ...interface{}) {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "no matching rule found")
}
- wi := &workloadidentityv1pb.WorkloadIdentity{
- Kind: types.KindWorkloadIdentity,
- Version: types.V1,
- Metadata: &headerv1.Metadata{
- Name: "test",
+
+ tests := []struct {
+ name string
+ wid *workloadidentityv1pb.WorkloadIdentity
+ attrs *workloadidentityv1pb.Attrs
+ requireErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "no rules: pass",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{},
+ },
+ },
+ attrs: attrs,
+ requireErr: require.NoError,
+ },
+ {
+ name: "eq: pass",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "foo",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: require.NoError,
+ },
+ {
+ name: "eq: fail",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "not-foo",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: noMatchRule,
+ },
+ {
+ name: "not_eq: pass",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotEq{
+ NotEq: &workloadidentityv1pb.WorkloadIdentityConditionNotEq{
+ Value: "bar",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: require.NoError,
+ },
+ {
+ name: "not_eq: fail",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotEq{
+ NotEq: &workloadidentityv1pb.WorkloadIdentityConditionNotEq{
+ Value: "foo",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: noMatchRule,
+ },
+ {
+ name: "in: pass",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_In{
+ In: &workloadidentityv1pb.WorkloadIdentityConditionIn{
+ Values: []string{"bar", "foo"},
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: require.NoError,
},
- Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
- Rules: &workloadidentityv1pb.WorkloadIdentityRules{
- Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
- {
- Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ name: "in: fail",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
{
- Attribute: "user.name",
- Equals: "foo",
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_In{
+ In: &workloadidentityv1pb.WorkloadIdentityConditionIn{
+ Values: []string{"bar", "fizz"},
+ },
+ },
+ },
+ },
},
},
},
},
},
+ attrs: attrs,
+ requireErr: noMatchRule,
},
+ {
+ name: "not_in: pass",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotIn{
+ NotIn: &workloadidentityv1pb.WorkloadIdentityConditionNotIn{
+ Values: []string{"bar", "fizz"},
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: require.NoError,
+ },
+ {
+ name: "not_in: fail",
+ wid: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotIn{
+ NotIn: &workloadidentityv1pb.WorkloadIdentityConditionNotIn{
+ Values: []string{"bar", "foo"},
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ attrs: attrs,
+ requireErr: noMatchRule,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := evaluateRules(tt.wid, tt.attrs)
+ tt.requireErr(t, err)
+ })
}
- err := evaluateRules(wi, attrs)
- require.NoError(t, err)
}
diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
index e5f23dc96216c..1ddf63bcf28d1 100644
--- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
+++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
@@ -187,7 +187,11 @@ func TestIssueWorkloadIdentityE2E(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "join.kubernetes.service_account.namespace",
- Equals: "my-namespace",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "my-namespace",
+ },
+ },
},
},
},
@@ -402,11 +406,19 @@ func TestIssueWorkloadIdentity(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "user.name",
- Equals: "dog",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "dog",
+ },
+ },
},
{
Attribute: "workload.kubernetes.namespace",
- Equals: "default",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "default",
+ },
+ },
},
},
},
@@ -768,7 +780,11 @@ func TestIssueWorkloadIdentities(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "workload.kubernetes.namespace",
- Equals: "default",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "default",
+ },
+ },
},
},
},
@@ -798,7 +814,11 @@ func TestIssueWorkloadIdentities(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "workload.kubernetes.namespace",
- Equals: "default",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "default",
+ },
+ },
},
},
},
diff --git a/lib/backend/memory/memory.go b/lib/backend/memory/memory.go
index cd00a6bb6efaa..4adb2b0779803 100644
--- a/lib/backend/memory/memory.go
+++ b/lib/backend/memory/memory.go
@@ -472,7 +472,7 @@ func (m *Memory) removeExpired() int {
}
m.heap.PopEl()
m.tree.Delete(item)
- m.logger.DebugContext(m.ctx, "Removed expired item.", "key", item.Key.String(), "epiry", item.Expires)
+ m.logger.DebugContext(m.ctx, "Removed expired item.", "key", item.Key.String(), "expiry", item.Expires)
removed++
event := backend.Event{
diff --git a/lib/client/api.go b/lib/client/api.go
index ed94462aa9c73..8b4c317265573 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -2853,7 +2853,7 @@ type execResult struct {
// sharedWriter is an [io.Writer] implementation that protects
// writes with a mutex. This allows a single [io.Writer] to be shared
-// by both logrus and slog without their output clobbering each other.
+// by multiple command runners.
type sharedWriter struct {
mu sync.Mutex
io.Writer
diff --git a/lib/cloud/aws/aws.go b/lib/cloud/aws/aws.go
index 27ea56321b7df..7361ff75f219c 100644
--- a/lib/cloud/aws/aws.go
+++ b/lib/cloud/aws/aws.go
@@ -22,12 +22,12 @@ import (
"slices"
"strings"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/aws/aws-sdk-go/service/opensearchservice"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/coreos/go-semver/semver"
"github.com/gravitational/teleport/lib/services"
@@ -74,18 +74,51 @@ func IsOpenSearchDomainAvailable(domain *opensearchservice.DomainStatus) bool {
}
// IsRDSProxyAvailable checks if the RDS Proxy is available.
-func IsRDSProxyAvailable(dbProxy *rds.DBProxy) bool {
- return IsResourceAvailable(dbProxy, dbProxy.Status)
+func IsRDSProxyAvailable(dbProxy *rdstypes.DBProxy) bool {
+ switch dbProxy.Status {
+ case
+ rdstypes.DBProxyStatusAvailable,
+ rdstypes.DBProxyStatusModifying,
+ rdstypes.DBProxyStatusReactivating:
+ return true
+ case
+ rdstypes.DBProxyStatusCreating,
+ rdstypes.DBProxyStatusDeleting,
+ rdstypes.DBProxyStatusIncompatibleNetwork,
+ rdstypes.DBProxyStatusInsufficientResourceLimits,
+ rdstypes.DBProxyStatusSuspended,
+ rdstypes.DBProxyStatusSuspending:
+ return false
+ }
+ slog.WarnContext(context.Background(), "Assuming RDS Proxy with unknown status is available",
+ "status", dbProxy.Status,
+ )
+ return true
}
// IsRDSProxyCustomEndpointAvailable checks if the RDS Proxy custom endpoint is available.
-func IsRDSProxyCustomEndpointAvailable(customEndpoint *rds.DBProxyEndpoint) bool {
- return IsResourceAvailable(customEndpoint, customEndpoint.Status)
+func IsRDSProxyCustomEndpointAvailable(customEndpoint *rdstypes.DBProxyEndpoint) bool {
+ switch customEndpoint.Status {
+ case
+ rdstypes.DBProxyEndpointStatusAvailable,
+ rdstypes.DBProxyEndpointStatusModifying:
+ return true
+ case
+ rdstypes.DBProxyEndpointStatusCreating,
+ rdstypes.DBProxyEndpointStatusDeleting,
+ rdstypes.DBProxyEndpointStatusIncompatibleNetwork,
+ rdstypes.DBProxyEndpointStatusInsufficientResourceLimits:
+ return false
+ }
+ slog.WarnContext(context.Background(), "Assuming RDS Proxy custom endpoint with unknown status is available",
+ "status", customEndpoint.Status,
+ )
+ return true
}
// IsRDSInstanceSupported returns true if database supports IAM authentication.
// Currently, only MariaDB is being checked.
-func IsRDSInstanceSupported(instance *rds.DBInstance) bool {
+func IsRDSInstanceSupported(instance *rdstypes.DBInstance) bool {
// TODO(jakule): Check other engines.
if aws.StringValue(instance.Engine) != services.RDSEngineMariaDB {
return true
@@ -105,7 +138,7 @@ func IsRDSInstanceSupported(instance *rds.DBInstance) bool {
}
// IsRDSClusterSupported checks whether the Aurora cluster is supported.
-func IsRDSClusterSupported(cluster *rds.DBCluster) bool {
+func IsRDSClusterSupported(cluster *rdstypes.DBCluster) bool {
switch aws.StringValue(cluster.EngineMode) {
// Aurora Serverless v1 does NOT support IAM authentication.
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-serverless.html#aurora-serverless.limitations
@@ -129,7 +162,7 @@ func IsRDSClusterSupported(cluster *rds.DBCluster) bool {
}
// AuroraMySQLVersion extracts aurora mysql version from engine version
-func AuroraMySQLVersion(cluster *rds.DBCluster) string {
+func AuroraMySQLVersion(cluster *rdstypes.DBCluster) string {
// version guide: https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Updates.Versions.html
// a list of all the available versions: https://docs.aws.amazon.com/cli/latest/reference/rds/describe-db-engine-versions.html
//
@@ -154,7 +187,7 @@ func AuroraMySQLVersion(cluster *rds.DBCluster) string {
// for this DocumentDB cluster.
//
// https://docs.aws.amazon.com/documentdb/latest/developerguide/iam-identity-auth.html
-func IsDocumentDBClusterSupported(cluster *rds.DBCluster) bool {
+func IsDocumentDBClusterSupported(cluster *rdstypes.DBCluster) bool {
ver, err := semver.NewVersion(aws.StringValue(cluster.EngineVersion))
if err != nil {
slog.ErrorContext(context.Background(), "Failed to parse DocumentDB engine version", "version", aws.StringValue(cluster.EngineVersion))
diff --git a/lib/cloud/aws/tags_helpers.go b/lib/cloud/aws/tags_helpers.go
index 3e61bd6fc1a42..43f6ba48f61ca 100644
--- a/lib/cloud/aws/tags_helpers.go
+++ b/lib/cloud/aws/tags_helpers.go
@@ -24,14 +24,13 @@ import (
"slices"
ec2TypesV2 "github.com/aws/aws-sdk-go-v2/service/ec2/types"
- rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/aws/aws-sdk-go/service/opensearchservice"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"golang.org/x/exp/maps"
@@ -43,11 +42,10 @@ import (
type ResourceTag interface {
// TODO Go generic does not allow access common fields yet. List all types
// here and use a type switch for now.
- rdsTypesV2.Tag |
+ rdstypes.Tag |
ec2TypesV2.Tag |
redshifttypes.Tag |
*ec2.Tag |
- *rds.Tag |
*elasticache.Tag |
*memorydb.Tag |
*redshiftserverless.Tag |
@@ -76,8 +74,6 @@ func TagsToLabels[Tag ResourceTag](tags []Tag) map[string]string {
func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) {
switch v := any(tag).(type) {
- case *rds.Tag:
- return aws.StringValue(v.Key), aws.StringValue(v.Value)
case *ec2.Tag:
return aws.StringValue(v.Key), aws.StringValue(v.Value)
case *elasticache.Tag:
@@ -86,7 +82,7 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) {
return aws.StringValue(v.Key), aws.StringValue(v.Value)
case *redshiftserverless.Tag:
return aws.StringValue(v.Key), aws.StringValue(v.Value)
- case rdsTypesV2.Tag:
+ case rdstypes.Tag:
return aws.StringValue(v.Key), aws.StringValue(v.Value)
case ec2TypesV2.Tag:
return aws.StringValue(v.Key), aws.StringValue(v.Value)
@@ -123,22 +119,3 @@ func LabelsToTags[T any, PT SettableTag[T]](labels map[string]string) (tags []*T
}
return
}
-
-// LabelsToRDSV2Tags converts labels into [rdsTypesV2.Tag] list.
-func LabelsToRDSV2Tags(labels map[string]string) []rdsTypesV2.Tag {
- keys := maps.Keys(labels)
- slices.Sort(keys)
-
- ret := make([]rdsTypesV2.Tag, 0, len(keys))
- for _, key := range keys {
- key := key
- value := labels[key]
-
- ret = append(ret, rdsTypesV2.Tag{
- Key: &key,
- Value: &value,
- })
- }
-
- return ret
-}
diff --git a/lib/cloud/aws/tags_helpers_test.go b/lib/cloud/aws/tags_helpers_test.go
index 228c477a316cb..0bc75677fefbd 100644
--- a/lib/cloud/aws/tags_helpers_test.go
+++ b/lib/cloud/aws/tags_helpers_test.go
@@ -22,10 +22,9 @@ import (
 	"testing"
 
-	rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types"
+	rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elasticache"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/stretchr/testify/require"
)
@@ -33,7 +33,7 @@ func TestTagsToLabels(t *testing.T) {
t.Parallel()
t.Run("rds", func(t *testing.T) {
- inputTags := []*rds.Tag{
+ inputTags := []rdstypes.Tag{
{
Key: aws.String("Env"),
Value: aws.String("dev"),
@@ -135,25 +135,4 @@ func TestLabelsToTags(t *testing.T) {
actualTags := LabelsToTags[elasticache.Tag](inputLabels)
require.Equal(t, expectTags, actualTags)
})
-
- t.Run("rdsV2", func(t *testing.T) {
- inputLabels := map[string]string{
- "labelB": "valueB",
- "labelA": "valueA",
- }
-
- expectTags := []rdsTypesV2.Tag{
- {
- Key: aws.String("labelA"),
- Value: aws.String("valueA"),
- },
- {
- Key: aws.String("labelB"),
- Value: aws.String("valueB"),
- },
- }
-
- actualTags := LabelsToRDSV2Tags(inputLabels)
- require.EqualValues(t, expectTags, actualTags)
- })
}
diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go
index 7b1cabe5ffe75..245fe8a9a6b23 100644
--- a/lib/cloud/awsconfig/awsconfig.go
+++ b/lib/cloud/awsconfig/awsconfig.go
@@ -280,11 +280,11 @@ func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Confi
}
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) {
- for _, r := range roles {
- cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
- }
if len(roles) > 0 {
- // no point caching every assumed role in the chain, we can just cache
+ for _, r := range roles {
+ cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
+ }
+ // No point caching every assumed role in the chain, we can just cache
// the last one.
cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions)
if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
diff --git a/lib/cloud/awstesthelpers/tags.go b/lib/cloud/awstesthelpers/tags.go
index 5e1f4aa0e0738..28bed6b973f0b 100644
--- a/lib/cloud/awstesthelpers/tags.go
+++ b/lib/cloud/awstesthelpers/tags.go
@@ -22,6 +22,7 @@ import (
"maps"
"slices"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
)
@@ -43,3 +44,22 @@ func LabelsToRedshiftTags(labels map[string]string) []redshifttypes.Tag {
return ret
}
+
+// LabelsToRDSTags converts labels into a [rdstypes.Tag] list.
+func LabelsToRDSTags(labels map[string]string) []rdstypes.Tag {
+ keys := slices.Collect(maps.Keys(labels))
+ slices.Sort(keys)
+
+ ret := make([]rdstypes.Tag, 0, len(keys))
+ for _, key := range keys {
+ key := key
+ value := labels[key]
+
+ ret = append(ret, rdstypes.Tag{
+ Key: &key,
+ Value: &value,
+ })
+ }
+
+ return ret
+}
diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go
index 99c2deb4001f0..cc50c98c1ba4f 100644
--- a/lib/cloud/clients.go
+++ b/lib/cloud/clients.go
@@ -39,8 +39,6 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
awssession "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface"
"github.com/aws/aws-sdk-go/service/iam"
@@ -51,8 +49,6 @@ import (
"github.com/aws/aws-sdk-go/service/memorydb/memorydbiface"
"github.com/aws/aws-sdk-go/service/opensearchservice"
"github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface"
"github.com/aws/aws-sdk-go/service/s3"
@@ -111,8 +107,6 @@ type GCPClients interface {
type AWSClients interface {
// GetAWSSession returns AWS session for the specified region and any role(s).
GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error)
- // GetAWSRDSClient returns AWS RDS client for the specified region.
- GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error)
// GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region.
GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error)
// GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region.
@@ -127,8 +121,6 @@ type AWSClients interface {
GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
- // GetAWSEKSClient returns AWS EKS client for the specified region.
- GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error)
// GetAWSKMSClient returns AWS KMS client for the specified region.
GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error)
// GetAWSS3Client returns AWS S3 client.
@@ -504,15 +496,6 @@ func (c *cloudClients) GetAWSSession(ctx context.Context, region string, opts ..
return c.getAWSSessionForRole(ctx, region, options)
}
-// GetAWSRDSClient returns AWS RDS client for the specified region.
-func (c *cloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) {
- session, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return rds.New(session), nil
-}
-
// GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region.
func (c *cloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
@@ -585,15 +568,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts
return sts.New(session), nil
}
-// GetAWSEKSClient returns AWS EKS client for the specified region.
-func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
- session, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return eks.New(session), nil
-}
-
// GetAWSKMSClient returns AWS KMS client for the specified region.
func (c *cloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
@@ -1018,8 +992,6 @@ var _ Clients = (*TestCloudClients)(nil)
// TestCloudClients are used in tests.
type TestCloudClients struct {
- RDS rdsiface.RDSAPI
- RDSPerRegion map[string]rdsiface.RDSAPI
RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI
ElastiCache elasticacheiface.ElastiCacheAPI
OpenSearch opensearchserviceiface.OpenSearchServiceAPI
@@ -1032,7 +1004,6 @@ type TestCloudClients struct {
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
InstanceMetadata imds.Client
- EKS eksiface.EKSAPI
KMS kmsiface.KMSAPI
S3 s3iface.S3API
AzureMySQL azure.DBServersClient
@@ -1089,18 +1060,6 @@ func (c *TestCloudClients) getAWSSessionForRegion(region string) (*awssession.Se
})
}
-// GetAWSRDSClient returns AWS RDS client for the specified region.
-func (c *TestCloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) {
- _, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- if len(c.RDSPerRegion) != 0 {
- return c.RDSPerRegion[region], nil
- }
- return c.RDS, nil
-}
-
// GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region.
func (c *TestCloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
@@ -1173,15 +1132,6 @@ func (c *TestCloudClients) GetAWSSTSClient(ctx context.Context, region string, o
return c.STS, nil
}
-// GetAWSEKSClient returns AWS EKS client for the specified region.
-func (c *TestCloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
- _, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return c.EKS, nil
-}
-
// GetAWSKMSClient returns AWS KMS client for the specified region.
func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go
index ceb50bd822cc2..9ba40628e3a92 100644
--- a/lib/cloud/mocks/aws.go
+++ b/lib/cloud/mocks/aws.go
@@ -28,8 +28,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
@@ -288,86 +286,3 @@ func (m *IAMErrorMock) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutU
}
return nil, trace.AccessDenied("unauthorized")
}
-
-// EKSMock is a mock EKS client.
-type EKSMock struct {
- eksiface.EKSAPI
- Clusters []*eks.Cluster
- AccessEntries []*eks.AccessEntry
- AssociatedPolicies []*eks.AssociatedAccessPolicy
- Notify chan struct{}
-}
-
-func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
- defer func() {
- if e.Notify != nil {
- e.Notify <- struct{}{}
- }
- }()
- for _, cluster := range e.Clusters {
- if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) {
- return &eks.DescribeClusterOutput{Cluster: cluster}, nil
- }
- }
- return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name))
-}
-
-func (e *EKSMock) ListClustersPagesWithContext(_ aws.Context, _ *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error {
- defer func() {
- if e.Notify != nil {
- e.Notify <- struct{}{}
- }
- }()
- clusters := make([]*string, 0, len(e.Clusters))
- for _, cluster := range e.Clusters {
- clusters = append(clusters, cluster.Name)
- }
- f(&eks.ListClustersOutput{
- Clusters: clusters,
- }, true)
- return nil
-}
-
-func (e *EKSMock) ListAccessEntriesPagesWithContext(_ aws.Context, _ *eks.ListAccessEntriesInput, f func(*eks.ListAccessEntriesOutput, bool) bool, _ ...request.Option) error {
- defer func() {
- if e.Notify != nil {
- e.Notify <- struct{}{}
- }
- }()
- accessEntries := make([]*string, 0, len(e.Clusters))
- for _, a := range e.AccessEntries {
- accessEntries = append(accessEntries, a.PrincipalArn)
- }
- f(&eks.ListAccessEntriesOutput{
- AccessEntries: accessEntries,
- }, true)
- return nil
-}
-
-func (e *EKSMock) DescribeAccessEntryWithContext(_ aws.Context, req *eks.DescribeAccessEntryInput, _ ...request.Option) (*eks.DescribeAccessEntryOutput, error) {
- defer func() {
- if e.Notify != nil {
- e.Notify <- struct{}{}
- }
- }()
- for _, a := range e.AccessEntries {
- if aws.StringValue(req.PrincipalArn) == aws.StringValue(a.PrincipalArn) && aws.StringValue(a.ClusterName) == aws.StringValue(req.ClusterName) {
- return &eks.DescribeAccessEntryOutput{AccessEntry: a}, nil
- }
- }
- return nil, trace.NotFound("access entry %v not found", aws.StringValue(req.PrincipalArn))
-}
-
-func (e *EKSMock) ListAssociatedAccessPoliciesPagesWithContext(_ aws.Context, _ *eks.ListAssociatedAccessPoliciesInput, f func(*eks.ListAssociatedAccessPoliciesOutput, bool) bool, _ ...request.Option) error {
- defer func() {
- if e.Notify != nil {
- e.Notify <- struct{}{}
- }
- }()
-
- f(&eks.ListAssociatedAccessPoliciesOutput{
- AssociatedAccessPolicies: e.AssociatedPolicies,
- }, true)
- return nil
-
-}
diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go
index b52dfbd36d74a..d148e9512c8d4 100644
--- a/lib/cloud/mocks/aws_config.go
+++ b/lib/cloud/mocks/aws_config.go
@@ -29,21 +29,26 @@ import (
)
type AWSConfigProvider struct {
+ Err error
STSClient *STSClient
OIDCIntegrationClient awsconfig.OIDCIntegrationClient
}
func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) {
+ if f.Err != nil {
+ return aws.Config{}, f.Err
+ }
+
stsClt := f.STSClient
if stsClt == nil {
stsClt = &STSClient{}
}
- optFns = append(optFns,
+ optFns = append([]awsconfig.OptionsFn{
awsconfig.WithOIDCIntegrationClient(f.OIDCIntegrationClient),
awsconfig.WithSTSClientProvider(
newAssumeRoleClientProviderFunc(stsClt),
),
- )
+ }, optFns...)
return awsconfig.GetConfig(ctx, region, optFns...)
}
diff --git a/lib/cloud/mocks/aws_rds.go b/lib/cloud/mocks/aws_rds.go
index 50130d668f5c0..9338b8330dc5f 100644
--- a/lib/cloud/mocks/aws_rds.go
+++ b/lib/cloud/mocks/aws_rds.go
@@ -19,159 +19,156 @@
package mocks
import (
+ "context"
"fmt"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdsv2 "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/google/uuid"
"github.com/gravitational/trace"
- libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awstesthelpers"
)
-// RDSMock mocks AWS RDS API.
-type RDSMock struct {
- rdsiface.RDSAPI
- DBInstances []*rds.DBInstance
- DBClusters []*rds.DBCluster
- DBProxies []*rds.DBProxy
- DBProxyEndpoints []*rds.DBProxyEndpoint
- DBEngineVersions []*rds.DBEngineVersion
+type RDSClient struct {
+ Unauth bool
+
+ DBInstances []rdstypes.DBInstance
+ DBClusters []rdstypes.DBCluster
+ DBProxies []rdstypes.DBProxy
+ DBProxyEndpoints []rdstypes.DBProxyEndpoint
+ DBEngineVersions []rdstypes.DBEngineVersion
}
-func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) {
- if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
+func (c *RDSClient) DescribeDBInstances(_ context.Context, input *rdsv2.DescribeDBInstancesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBInstancesOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
+ }
+
+ if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil {
return nil, trace.Wrap(err)
}
- instances, err := applyInstanceFilters(m.DBInstances, input.Filters)
+ instances, err := applyInstanceFilters(c.DBInstances, input.Filters)
if err != nil {
return nil, trace.Wrap(err)
}
if aws.StringValue(input.DBInstanceIdentifier) == "" {
- return &rds.DescribeDBInstancesOutput{
+ return &rdsv2.DescribeDBInstancesOutput{
DBInstances: instances,
}, nil
}
for _, instance := range instances {
if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) {
- return &rds.DescribeDBInstancesOutput{
- DBInstances: []*rds.DBInstance{instance},
+ return &rdsv2.DescribeDBInstancesOutput{
+ DBInstances: []rdstypes.DBInstance{instance},
}, nil
}
}
return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier))
}
-func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error {
- if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
- return trace.Wrap(err)
- }
- instances, err := applyInstanceFilters(m.DBInstances, input.Filters)
- if err != nil {
- return trace.Wrap(err)
+func (c *RDSClient) DescribeDBClusters(_ context.Context, input *rdsv2.DescribeDBClustersInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBClustersOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
}
- fn(&rds.DescribeDBInstancesOutput{
- DBInstances: instances,
- }, true)
- return nil
-}
-func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) {
- if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
+ if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil {
return nil, trace.Wrap(err)
}
- clusters, err := applyClusterFilters(m.DBClusters, input.Filters)
+ clusters, err := applyClusterFilters(c.DBClusters, input.Filters)
if err != nil {
return nil, trace.Wrap(err)
}
if aws.StringValue(input.DBClusterIdentifier) == "" {
- return &rds.DescribeDBClustersOutput{
+ return &rdsv2.DescribeDBClustersOutput{
DBClusters: clusters,
}, nil
}
for _, cluster := range clusters {
if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) {
- return &rds.DescribeDBClustersOutput{
- DBClusters: []*rds.DBCluster{cluster},
+ return &rdsv2.DescribeDBClustersOutput{
+ DBClusters: []rdstypes.DBCluster{cluster},
}, nil
}
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier))
}
-func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
- if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
- return trace.Wrap(err)
- }
- clusters, err := applyClusterFilters(m.DBClusters, input.Filters)
- if err != nil {
- return trace.Wrap(err)
+func (c *RDSClient) ModifyDBInstance(ctx context.Context, input *rdsv2.ModifyDBInstanceInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBInstanceOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
}
- fn(&rds.DescribeDBClustersOutput{
- DBClusters: clusters,
- }, true)
- return nil
-}
-func (m *RDSMock) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) {
- for i, instance := range m.DBInstances {
+ for i, instance := range c.DBInstances {
if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) {
if aws.BoolValue(input.EnableIAMDatabaseAuthentication) {
- m.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true)
+ c.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true)
}
- return &rds.ModifyDBInstanceOutput{
- DBInstance: m.DBInstances[i],
+ return &rdsv2.ModifyDBInstanceOutput{
+ DBInstance: &c.DBInstances[i],
}, nil
}
}
return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier))
}
-func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) {
- for i, cluster := range m.DBClusters {
+func (c *RDSClient) ModifyDBCluster(ctx context.Context, input *rdsv2.ModifyDBClusterInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBClusterOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
+ }
+
+ for i, cluster := range c.DBClusters {
if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) {
if aws.BoolValue(input.EnableIAMDatabaseAuthentication) {
- m.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true)
+ c.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true)
}
- return &rds.ModifyDBClusterOutput{
- DBCluster: m.DBClusters[i],
+ return &rdsv2.ModifyDBClusterOutput{
+ DBCluster: &c.DBClusters[i],
}, nil
}
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier))
}
-func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
+func (c *RDSClient) DescribeDBProxies(_ context.Context, input *rdsv2.DescribeDBProxiesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxiesOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
+ }
+
if aws.StringValue(input.DBProxyName) == "" {
- return &rds.DescribeDBProxiesOutput{
- DBProxies: m.DBProxies,
+ return &rdsv2.DescribeDBProxiesOutput{
+ DBProxies: c.DBProxies,
}, nil
}
- for _, dbProxy := range m.DBProxies {
+ for _, dbProxy := range c.DBProxies {
if aws.StringValue(dbProxy.DBProxyName) == aws.StringValue(input.DBProxyName) {
- return &rds.DescribeDBProxiesOutput{
- DBProxies: []*rds.DBProxy{dbProxy},
+ return &rdsv2.DescribeDBProxiesOutput{
+ DBProxies: []rdstypes.DBProxy{dbProxy},
}, nil
}
}
return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName))
}
-func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
+func (c *RDSClient) DescribeDBProxyEndpoints(_ context.Context, input *rdsv2.DescribeDBProxyEndpointsInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxyEndpointsOutput, error) {
+ if c.Unauth {
+ return nil, trace.AccessDenied("unauthorized")
+ }
+
inputProxyName := aws.StringValue(input.DBProxyName)
inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName)
if inputProxyName == "" && inputProxyEndpointName == "" {
- return &rds.DescribeDBProxyEndpointsOutput{
- DBProxyEndpoints: m.DBProxyEndpoints,
+ return &rdsv2.DescribeDBProxyEndpointsOutput{
+ DBProxyEndpoints: c.DBProxyEndpoints,
}, nil
}
- var endpoints []*rds.DBProxyEndpoint
- for _, dbProxyEndpoiont := range m.DBProxyEndpoints {
+ var endpoints []rdstypes.DBProxyEndpoint
+ for _, dbProxyEndpoiont := range c.DBProxyEndpoints {
if inputProxyEndpointName != "" &&
inputProxyEndpointName != aws.StringValue(dbProxyEndpoiont.DBProxyEndpointName) {
continue
@@ -187,114 +184,15 @@ func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rd
if len(endpoints) == 0 {
return nil, trace.NotFound("proxy endpoint %v not found", aws.StringValue(input.DBProxyEndpointName))
}
- return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil
-}
-
-func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
- fn(&rds.DescribeDBProxiesOutput{
- DBProxies: m.DBProxies,
- }, true)
- return nil
+ return &rdsv2.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil
}
-func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error {
- fn(&rds.DescribeDBProxyEndpointsOutput{
- DBProxyEndpoints: m.DBProxyEndpoints,
- }, true)
- return nil
-}
-
-func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) {
+func (c *RDSClient) ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) {
return &rds.ListTagsForResourceOutput{}, nil
}
-// RDSMockUnauth is a mock RDS client that returns access denied to each call.
-type RDSMockUnauth struct {
- rdsiface.RDSAPI
-}
-
-func (m *RDSMockUnauth) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error {
- return trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
- return trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
- return nil, trace.AccessDenied("unauthorized")
-}
-
-func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
- return trace.AccessDenied("unauthorized")
-}
-
-// RDSMockByDBType is a mock RDS client that mocks API calls by DB type
-type RDSMockByDBType struct {
- rdsiface.RDSAPI
- DBInstances rdsiface.RDSAPI
- DBClusters rdsiface.RDSAPI
- DBProxies rdsiface.RDSAPI
-}
-
-func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) {
- return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) {
- return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error {
- return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) {
- return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) {
- return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
- return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
- return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
- return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...)
-}
-
-func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
- return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...)
-}
-
// checkEngineFilters checks RDS filters to detect unrecognized engine filters.
-func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error {
+func checkEngineFilters(filters []rdstypes.Filter, engineVersions []rdstypes.DBEngineVersion) error {
if len(filters) == 0 {
return nil
}
@@ -307,8 +205,8 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer
continue
}
for _, v := range f.Values {
- if _, ok := recognizedEngines[aws.StringValue(v)]; !ok {
- return trace.Errorf("unrecognized engine name %q", aws.StringValue(v))
+ if _, ok := recognizedEngines[v]; !ok {
+ return trace.Errorf("unrecognized engine name %q", v)
}
}
}
@@ -316,11 +214,11 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer
}
// applyInstanceFilters filters RDS DBInstances using the provided RDS filters.
-func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) {
+func applyInstanceFilters(in []rdstypes.DBInstance, filters []rdstypes.Filter) ([]rdstypes.DBInstance, error) {
if len(filters) == 0 {
return in, nil
}
- var out []*rds.DBInstance
+ var out []rdstypes.DBInstance
efs := engineFilterSet(filters)
clusterIDs := clusterIdentifierFilterSet(filters)
for _, instance := range in {
@@ -336,11 +234,11 @@ func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.D
}
// applyClusterFilters filters RDS DBClusters using the provided RDS filters.
-func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) {
+func applyClusterFilters(in []rdstypes.DBCluster, filters []rdstypes.Filter) ([]rdstypes.DBCluster, error) {
if len(filters) == 0 {
return in, nil
}
- var out []*rds.DBCluster
+ var out []rdstypes.DBCluster
efs := engineFilterSet(filters)
for _, cluster := range in {
if clusterEngineMatches(cluster, efs) {
@@ -351,59 +249,59 @@ func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBC
}
// engineFilterSet builds a string set of engine names from a list of RDS filters.
-func engineFilterSet(filters []*rds.Filter) map[string]struct{} {
+func engineFilterSet(filters []rdstypes.Filter) map[string]struct{} {
return filterValues(filters, "engine")
}
// clusterIdentifierFilterSet builds a string set of ClusterIDs from a list of RDS filters.
-func clusterIdentifierFilterSet(filters []*rds.Filter) map[string]struct{} {
+func clusterIdentifierFilterSet(filters []rdstypes.Filter) map[string]struct{} {
return filterValues(filters, "db-cluster-id")
}
-func filterValues(filters []*rds.Filter, filterKey string) map[string]struct{} {
+func filterValues(filters []rdstypes.Filter, filterKey string) map[string]struct{} {
out := make(map[string]struct{})
for _, f := range filters {
if aws.StringValue(f.Name) != filterKey {
continue
}
for _, v := range f.Values {
- out[aws.StringValue(v)] = struct{}{}
+ out[v] = struct{}{}
}
}
return out
}
// instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set.
-func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool {
+func instanceEngineMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(instance.Engine)]
return ok
}
// instanceClusterIDMatches returns whether an RDS DBInstance ClusterID matches any ClusterID in a filter set.
-func instanceClusterIDMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool {
+func instanceClusterIDMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(instance.DBClusterIdentifier)]
return ok
}
// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set.
-func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool {
+func clusterEngineMatches(cluster rdstypes.DBCluster, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(cluster.Engine)]
return ok
}
-// RDSInstance returns a sample rds.DBInstance.
-func RDSInstance(name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) *rds.DBInstance {
- instance := &rds.DBInstance{
+// RDSInstance returns a sample rdstypes.DBInstance.
+func RDSInstance(name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) *rdstypes.DBInstance {
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)),
DBInstanceIdentifier: aws.String(name),
DbiResourceId: aws.String(uuid.New().String()),
Engine: aws.String("postgres"),
DBInstanceStatus: aws.String("available"),
- Endpoint: &rds.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.rds.amazonaws.com", name, region)),
- Port: aws.Int64(5432),
+ Port: aws.Int32(5432),
},
- TagList: libcloudaws.LabelsToTags[rds.Tag](labels),
+ TagList: awstesthelpers.LabelsToRDSTags(labels),
}
for _, opt := range opts {
opt(instance)
@@ -411,9 +309,9 @@ func RDSInstance(name, region string, labels map[string]string, opts ...func(*rd
return instance
}
-// RDSCluster returns a sample rds.DBCluster.
-func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster {
- cluster := &rds.DBCluster{
+// RDSCluster returns a sample rdstypes.DBCluster.
+func RDSCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster {
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)),
DBClusterIdentifier: aws.String(name),
DbClusterResourceId: aws.String(uuid.New().String()),
@@ -422,9 +320,9 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds
Status: aws.String("available"),
Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.rds.amazonaws.com", name, region)),
ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.rds.amazonaws.com", name, region)),
- Port: aws.Int64(3306),
- TagList: libcloudaws.LabelsToTags[rds.Tag](labels),
- DBClusterMembers: []*rds.DBClusterMember{{
+ Port: aws.Int32(3306),
+ TagList: awstesthelpers.LabelsToRDSTags(labels),
+ DBClusterMembers: []rdstypes.DBClusterMember{{
IsClusterWriter: aws.Bool(true), // One writer by default.
}},
}
@@ -434,49 +332,49 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds
return cluster
}
-func WithRDSClusterReader(cluster *rds.DBCluster) {
- cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{
+func WithRDSClusterReader(cluster *rdstypes.DBCluster) {
+ cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{
IsClusterWriter: aws.Bool(false), // Add reader.
})
}
-func WithRDSClusterCustomEndpoint(name string) func(*rds.DBCluster) {
- return func(cluster *rds.DBCluster) {
+func WithRDSClusterCustomEndpoint(name string) func(*rdstypes.DBCluster) {
+ return func(cluster *rdstypes.DBCluster) {
parsed, _ := arn.Parse(aws.StringValue(cluster.DBClusterArn))
- cluster.CustomEndpoints = append(cluster.CustomEndpoints, aws.String(
+ cluster.CustomEndpoints = append(cluster.CustomEndpoints,
fmt.Sprintf("%v.cluster-custom-aabbccdd.%v.rds.amazonaws.com", name, parsed.Region),
- ))
+ )
}
}
-// RDSProxy returns a sample rds.DBProxy.
-func RDSProxy(name, region, vpcID string) *rds.DBProxy {
- return &rds.DBProxy{
+// RDSProxy returns a sample rdstypes.DBProxy.
+func RDSProxy(name, region, vpcID string) *rdstypes.DBProxy {
+ return &rdstypes.DBProxy{
DBProxyArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:123456789012:db-proxy:prx-%s", region, name)),
DBProxyName: aws.String(name),
- EngineFamily: aws.String(rds.EngineFamilyMysql),
+ EngineFamily: aws.String(string(rdstypes.EngineFamilyMysql)),
Endpoint: aws.String(fmt.Sprintf("%s.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)),
VpcId: aws.String(vpcID),
RequireTLS: aws.Bool(true),
- Status: aws.String("available"),
+ Status: "available",
}
}
-// RDSProxyCustomEndpoint returns a sample rds.DBProxyEndpoint.
-func RDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, name, region string) *rds.DBProxyEndpoint {
- return &rds.DBProxyEndpoint{
+// RDSProxyCustomEndpoint returns a sample rdstypes.DBProxyEndpoint.
+func RDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, name, region string) *rdstypes.DBProxyEndpoint {
+ return &rdstypes.DBProxyEndpoint{
Endpoint: aws.String(fmt.Sprintf("%s.endpoint.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)),
DBProxyEndpointName: aws.String(name),
DBProxyName: rdsProxy.DBProxyName,
DBProxyEndpointArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db-proxy-endpoint:prx-endpoint-%v", region, name)),
- TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly),
- Status: aws.String("available"),
+ TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly,
+ Status: "available",
}
}
-// DocumentDBCluster returns a sample rds.DBCluster for DocumentDB.
-func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster {
- cluster := &rds.DBCluster{
+// DocumentDBCluster returns a sample rdstypes.DBCluster for DocumentDB.
+func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster {
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)),
DBClusterIdentifier: aws.String(name),
DbClusterResourceId: aws.String(uuid.New().String()),
@@ -485,9 +383,9 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu
Status: aws.String("available"),
Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.docdb.amazonaws.com", name, region)),
ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.docdb.amazonaws.com", name, region)),
- Port: aws.Int64(27017),
- TagList: libcloudaws.LabelsToTags[rds.Tag](labels),
- DBClusterMembers: []*rds.DBClusterMember{{
+ Port: aws.Int32(27017),
+ TagList: awstesthelpers.LabelsToRDSTags(labels),
+ DBClusterMembers: []rdstypes.DBClusterMember{{
IsClusterWriter: aws.Bool(true), // One writer by default.
}},
}
@@ -497,6 +395,6 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu
return cluster
}
-func WithDocumentDBClusterReader(cluster *rds.DBCluster) {
+func WithDocumentDBClusterReader(cluster *rdstypes.DBCluster) {
WithRDSClusterReader(cluster)
}
diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go
index 178a1259669a4..cf117788e696f 100644
--- a/lib/cloud/mocks/aws_sts.go
+++ b/lib/cloud/mocks/aws_sts.go
@@ -54,6 +54,12 @@ type STSClient struct {
recordFn func(roleARN, externalID string)
}
+func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
+ return &sts.GetCallerIdentityOutput{
+ Arn: aws.String(m.ARN),
+ }, nil
+}
+
func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
m.record(aws.ToString(in.RoleArn), "")
expiry := time.Now().Add(60 * time.Minute)
diff --git a/lib/events/complete.go b/lib/events/complete.go
index de9022391a533..881e02e80ce62 100644
--- a/lib/events/complete.go
+++ b/lib/events/complete.go
@@ -271,7 +271,7 @@ func (u *UploadCompleter) CheckUploads(ctx context.Context) error {
continue
}
- log.DebugContext(ctx, "foud upload with parts", "part_count", len(parts))
+ log.DebugContext(ctx, "found upload with parts", "part_count", len(parts))
if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); trace.IsNotFound(err) {
log.DebugContext(ctx, "Upload not found, moving on to next upload", "error", err)
diff --git a/lib/integrations/awsoidc/eks_enroll_clusters.go b/lib/integrations/awsoidc/eks_enroll_clusters.go
index dbeb6f2385484..d61b062cccfdb 100644
--- a/lib/integrations/awsoidc/eks_enroll_clusters.go
+++ b/lib/integrations/awsoidc/eks_enroll_clusters.go
@@ -74,8 +74,10 @@ const (
concurrentEKSEnrollingLimit = 5
)
-var agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
-var agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
+var (
+ agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
+ agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
+)
// EnrollEKSClusterResult contains result for a single EKS cluster enrollment, if it was successful 'Error' will be nil
// otherwise it will contain an error happened during enrollment.
@@ -462,7 +464,6 @@ func enrollEKSCluster(ctx context.Context, log *slog.Logger, clock clockwork.Clo
return "",
issueTypeFromCheckAgentInstalledError(err),
trace.Wrap(err, "could not check if teleport-kube-agent is already installed.")
-
} else if alreadyInstalled {
return "",
// When using EKS Auto Discovery, after the Kube Agent connects to the Teleport cluster, it is ignored in next discovery iterations.
@@ -708,7 +709,8 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error {
if cfg.req.IsCloud && cfg.req.EnableAutoUpgrades {
vals["updater"] = map[string]any{"enabled": true, "releaseChannel": "stable/cloud"}
- vals["highAvailability"] = map[string]any{"replicaCount": 2,
+ vals["highAvailability"] = map[string]any{
+ "replicaCount": 2,
"podDisruptionBudget": map[string]any{"enabled": true, "minAvailable": 1},
}
}
@@ -716,11 +718,10 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error {
vals["enterprise"] = true
}
- eksTags := make(map[string]*string, len(cfg.eksCluster.Tags))
- for k, v := range cfg.eksCluster.Tags {
- eksTags[k] = aws.String(v)
- }
- eksTags[types.OriginLabel] = aws.String(types.OriginCloud)
+ eksTags := make(map[string]string, len(cfg.eksCluster.Tags))
+ maps.Copy(eksTags, cfg.eksCluster.Tags)
+ eksTags[types.OriginLabel] = types.OriginCloud
+
kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cfg.eksCluster.Name), aws.ToString(cfg.eksCluster.Arn), eksTags)
if err != nil {
return trace.Wrap(err)
diff --git a/lib/kube/proxy/cluster_details.go b/lib/kube/proxy/cluster_details.go
index 1a66ce0562978..e1dbc45fca281 100644
--- a/lib/kube/proxy/cluster_details.go
+++ b/lib/kube/proxy/cluster_details.go
@@ -26,8 +26,8 @@ import (
"sync"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
@@ -50,6 +51,7 @@ import (
// kubeDetails contain the cluster-related details including authentication.
type kubeDetails struct {
kubeCreds
+
// dynamicLabels is the dynamic labels executor for this cluster.
dynamicLabels *labels.Dynamic
// kubeCluster is the dynamic kube_cluster or a static generated from kubeconfig and that only has the name populated.
@@ -86,6 +88,8 @@ type kubeDetails struct {
type clusterDetailsConfig struct {
// cloudClients is the cloud clients to use for dynamic clusters.
cloudClients cloud.Clients
+ // awsCloudClients provides AWS SDK clients.
+ awsCloudClients AWSClientGetter
// kubeCreds is the credentials to use for the cluster.
kubeCreds kubeCreds
// cluster is the cluster to create a proxied cluster for.
@@ -103,8 +107,10 @@ type clusterDetailsConfig struct {
component KubeServiceType
}
-const defaultRefreshPeriod = 5 * time.Minute
-const backoffRefreshStep = 10 * time.Second
+const (
+ defaultRefreshPeriod = 5 * time.Minute
+ backoffRefreshStep = 10 * time.Second
+)
// newClusterDetails creates a proxied kubeDetails structure given a dynamic cluster.
func newClusterDetails(ctx context.Context, cfg clusterDetailsConfig) (_ *kubeDetails, err error) {
@@ -263,14 +269,20 @@ func (k *kubeDetails) getObjectGVK(resource apiResource) *schema.GroupVersionKin
// getKubeClusterCredentials generates kube credentials for dynamic clusters.
func getKubeClusterCredentials(ctx context.Context, cfg clusterDetailsConfig) (kubeCreds, error) {
- dynCredsCfg := dynamicCredsConfig{kubeCluster: cfg.cluster, log: cfg.log, checker: cfg.checker, resourceMatchers: cfg.resourceMatchers, clock: cfg.clock, component: cfg.component}
- switch {
+ switch dynCredsCfg := (dynamicCredsConfig{
+ kubeCluster: cfg.cluster,
+ log: cfg.log,
+ checker: cfg.checker,
+ resourceMatchers: cfg.resourceMatchers,
+ clock: cfg.clock,
+ component: cfg.component,
+ }); {
case cfg.cluster.IsKubeconfig():
return getStaticCredentialsFromKubeconfig(ctx, cfg.component, cfg.cluster, cfg.log, cfg.checker)
case cfg.cluster.IsAzure():
return getAzureCredentials(ctx, cfg.cloudClients, dynCredsCfg)
case cfg.cluster.IsAWS():
- return getAWSCredentials(ctx, cfg.cloudClients, dynCredsCfg)
+ return getAWSCredentials(ctx, cfg.awsCloudClients, dynCredsCfg)
case cfg.cluster.IsGCP():
return getGCPCredentials(ctx, cfg.cloudClients, dynCredsCfg)
default:
@@ -308,7 +320,7 @@ func azureRestConfigClient(cloudClients cloud.Clients) dynamicCredsClient {
}
// getAWSCredentials creates a dynamicKubeCreds that generates and updates the access credentials to a EKS kubernetes cluster.
-func getAWSCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) {
+func getAWSCredentials(ctx context.Context, cloudClients AWSClientGetter, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) {
// create a client that returns the credentials for kubeCluster
cfg.client = getAWSClientRestConfig(cloudClients, cfg.clock, cfg.resourceMatchers)
creds, err := newDynamicKubeCreds(ctx, cfg)
@@ -328,51 +340,66 @@ func getAWSResourceMatcherToCluster(kubeCluster types.KubeCluster, resourceMatch
if match, _, _ := services.MatchLabels(matcher.Labels, kubeCluster.GetAllLabels()); !match {
continue
}
-
- return &(matcher.AWS)
+ return &matcher.AWS
}
return nil
}
+// STSPresignClient is the subset of the STS presign interface we use in fetchers.
+type STSPresignClient = kubeutils.STSPresignClient
+
+// EKSClient is the subset of the EKS Client interface we use.
+type EKSClient interface {
+ eks.DescribeClusterAPIClient
+}
+
+// AWSClientGetter is an interface for getting an EKS client and an STS client.
+type AWSClientGetter interface {
+ awsconfig.Provider
+ // GetAWSEKSClient returns AWS EKS client for the specified config.
+ GetAWSEKSClient(aws.Config) EKSClient
+ // GetAWSSTSPresignClient returns AWS STS presign client for the specified config.
+ GetAWSSTSPresignClient(aws.Config) STSPresignClient
+}
+
// getAWSClientRestConfig creates a dynamicCredsClient that generates returns credentials to EKS clusters.
-func getAWSClientRestConfig(cloudClients cloud.Clients, clock clockwork.Clock, resourceMatchers []services.ResourceMatcher) dynamicCredsClient {
+func getAWSClientRestConfig(cloudClients AWSClientGetter, clock clockwork.Clock, resourceMatchers []services.ResourceMatcher) dynamicCredsClient {
return func(ctx context.Context, cluster types.KubeCluster) (*rest.Config, time.Time, error) {
region := cluster.GetAWSConfig().Region
- opts := []cloud.AWSOptionsFn{
- cloud.WithAmbientCredentials(),
- cloud.WithoutSessionCache(),
+ opts := []awsconfig.OptionsFn{
+ awsconfig.WithAmbientCredentials(),
}
if awsAssume := getAWSResourceMatcherToCluster(cluster, resourceMatchers); awsAssume != nil {
- opts = append(opts, cloud.WithAssumeRole(awsAssume.AssumeRoleARN, awsAssume.ExternalID))
+ opts = append(opts, awsconfig.WithAssumeRole(awsAssume.AssumeRoleARN, awsAssume.ExternalID))
}
- regionalClient, err := cloudClients.GetAWSEKSClient(ctx, region, opts...)
+
+ cfg, err := cloudClients.GetConfig(ctx, region, opts...)
if err != nil {
return nil, time.Time{}, trace.Wrap(err)
}
- eksCfg, err := regionalClient.DescribeClusterWithContext(ctx, &eks.DescribeClusterInput{
+ regionalClient := cloudClients.GetAWSEKSClient(cfg)
+
+ eksCfg, err := regionalClient.DescribeCluster(ctx, &eks.DescribeClusterInput{
Name: aws.String(cluster.GetAWSConfig().Name),
})
if err != nil {
return nil, time.Time{}, trace.Wrap(err)
}
- ca, err := base64.StdEncoding.DecodeString(aws.StringValue(eksCfg.Cluster.CertificateAuthority.Data))
+ ca, err := base64.StdEncoding.DecodeString(aws.ToString(eksCfg.Cluster.CertificateAuthority.Data))
if err != nil {
return nil, time.Time{}, trace.Wrap(err)
}
- apiEndpoint := aws.StringValue(eksCfg.Cluster.Endpoint)
+ apiEndpoint := aws.ToString(eksCfg.Cluster.Endpoint)
if len(apiEndpoint) == 0 {
return nil, time.Time{}, trace.BadParameter("invalid api endpoint for cluster %q", cluster.GetAWSConfig().Name)
}
- stsClient, err := cloudClients.GetAWSSTSClient(ctx, region, opts...)
- if err != nil {
- return nil, time.Time{}, trace.Wrap(err)
- }
+ stsPresignClient := cloudClients.GetAWSSTSPresignClient(cfg)
- token, exp, err := kubeutils.GenAWSEKSToken(stsClient, cluster.GetAWSConfig().Name, clock)
+ token, exp, err := kubeutils.GenAWSEKSToken(ctx, stsPresignClient, cluster.GetAWSConfig().Name, clock)
if err != nil {
return nil, time.Time{}, trace.Wrap(err)
}
diff --git a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go
index ca4f1bd4b58e0..ca2f537e6de05 100644
--- a/lib/kube/proxy/kube_creds_test.go
+++ b/lib/kube/proxy/kube_creds_test.go
@@ -26,8 +26,11 @@ import (
"testing"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
@@ -41,10 +44,65 @@ import (
"github.com/gravitational/teleport/lib/cloud/gcp"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/fixtures"
+ kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)
+type mockEKSClientGetter struct {
+ mocks.AWSConfigProvider
+ stsPresignClient *mockSTSPresignAPI
+ eksClient *mockEKSAPI
+}
+
+func (e *mockEKSClientGetter) GetAWSEKSClient(aws.Config) EKSClient {
+ return e.eksClient
+}
+
+func (e *mockEKSClientGetter) GetAWSSTSPresignClient(aws.Config) kubeutils.STSPresignClient {
+ return e.stsPresignClient
+}
+
+type mockSTSPresignAPI struct {
+ url *url.URL
+}
+
+func (a *mockSTSPresignAPI) PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) {
+ return &v4.PresignedHTTPRequest{URL: a.url.String()}, nil
+}
+
+type mockEKSAPI struct {
+ EKSClient
+
+ notify chan struct{}
+ clusters []*ekstypes.Cluster
+}
+
+func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) {
+ defer func() { m.notify <- struct{}{} }()
+
+ var names []string
+ for _, cluster := range m.clusters {
+ names = append(names, aws.ToString(cluster.Name))
+ }
+ return &eks.ListClustersOutput{
+ Clusters: names,
+ }, nil
+}
+
+func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) {
+ defer func() { m.notify <- struct{}{} }()
+
+ for _, cluster := range m.clusters {
+ if aws.ToString(cluster.Name) == aws.ToString(req.Name) {
+ return &eks.DescribeClusterOutput{
+ Cluster: cluster,
+ }, nil
+ }
+ }
+ return nil, trace.NotFound("cluster %q not found", aws.ToString(req.Name))
+}
+
// Test_DynamicKubeCreds tests the dynamic kube credrentials generator for
// AWS, GCP, and Azure clusters accessed using their respective IAM credentials.
// This test mocks the cloud provider clients and the STS client to generate
@@ -99,32 +157,37 @@ func Test_DynamicKubeCreds(t *testing.T) {
)
require.NoError(t, err)
- // mock sts client
+ // Mock sts client.
u := &url.URL{
Scheme: "https",
Host: "sts.amazonaws.com",
Path: "/?Action=GetCallerIdentity&Version=2011-06-15",
}
- sts := &mocks.STSClientV1{
- // u is used to presign the request
- // here we just verify the pre-signed request includes this url.
- URL: u,
- }
- // mock clients
- cloudclients := &cloud.TestCloudClients{
- STS: sts,
- EKS: &mocks.EKSMock{
- Notify: notify,
- Clusters: []*eks.Cluster{
+ // EKS clients.
+ eksClients := &mockEKSClientGetter{
+ AWSConfigProvider: mocks.AWSConfigProvider{
+ STSClient: &mocks.STSClient{},
+ },
+ stsPresignClient: &mockSTSPresignAPI{
+ // u is used to presign the request
+ // here we just verify the pre-signed request includes this url.
+ url: u,
+ },
+ eksClient: &mockEKSAPI{
+ notify: notify,
+ clusters: []*ekstypes.Cluster{
{
Endpoint: aws.String("https://api.eks.us-west-2.amazonaws.com"),
Name: aws.String(awsKube.GetAWSConfig().Name),
- CertificateAuthority: &eks.Certificate{
+ CertificateAuthority: &ekstypes.Certificate{
Data: aws.String(base64.RawStdEncoding.EncodeToString([]byte(fixtures.TLSCACertPEM))),
},
},
},
},
+ }
+ // Mock clients.
+ cloudclients := &cloud.TestCloudClients{
GCPGKE: &mocks.GKEMock{
Notify: notify,
Clock: fakeClock,
@@ -204,7 +267,7 @@ func Test_DynamicKubeCreds(t *testing.T) {
name: "aws eks cluster without assume role",
args: args{
cluster: awsKube,
- client: getAWSClientRestConfig(cloudclients, fakeClock, nil),
+ client: getAWSClientRestConfig(eksClients, fakeClock, nil),
validateBearerToken: validateEKSToken,
},
wantAddr: "api.eks.us-west-2.amazonaws.com:443",
@@ -213,7 +276,7 @@ func Test_DynamicKubeCreds(t *testing.T) {
name: "aws eks cluster with unmatched assume role",
args: args{
cluster: awsKube,
- client: getAWSClientRestConfig(cloudclients, fakeClock, []services.ResourceMatcher{
+ client: getAWSClientRestConfig(eksClients, fakeClock, []services.ResourceMatcher{
{
Labels: types.Labels{
"rand": []string{"value"},
@@ -233,7 +296,7 @@ func Test_DynamicKubeCreds(t *testing.T) {
args: args{
cluster: awsKube,
client: getAWSClientRestConfig(
- cloudclients,
+ eksClients,
fakeClock,
[]services.ResourceMatcher{
{
@@ -331,6 +394,7 @@ func Test_DynamicKubeCreds(t *testing.T) {
}
require.NoError(t, got.close())
+ sts := eksClients.AWSConfigProvider.STSClient
require.Equal(t, tt.wantAssumedRole, apiutils.Deduplicate(sts.GetAssumedRoleARNs()))
require.Equal(t, tt.wantExternalIds, apiutils.Deduplicate(sts.GetAssumedRoleExternalIDs()))
sts.ResetAssumeRoleHistory()
diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go
index 6ac466746b51f..f153039d60749 100644
--- a/lib/kube/proxy/server.go
+++ b/lib/kube/proxy/server.go
@@ -28,6 +28,9 @@ import (
"sync"
"time"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"golang.org/x/net/http2"
@@ -38,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/inventory"
"github.com/gravitational/teleport/lib/labels"
@@ -74,6 +78,7 @@ type TLSServerConfig struct {
OnReconcile func(types.KubeClusters)
// CloudClients is a set of cloud clients that Teleport supports.
CloudClients cloud.Clients
+ awsClients *awsClientsGetter
// StaticLabels is a map of static labels associated with this service.
// Each cluster advertised by this kubernetes_service will include these static labels.
// If the service and a cluster define labels with the same key,
@@ -106,6 +111,21 @@ type TLSServerConfig struct {
InventoryHandle inventory.DownstreamHandle
}
+type awsClientsGetter struct{}
+
+func (f *awsClientsGetter) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) {
+ return awsconfig.GetConfig(ctx, region, optFns...)
+}
+
+func (f *awsClientsGetter) GetAWSEKSClient(cfg aws.Config) EKSClient {
+ return eks.NewFromConfig(cfg)
+}
+
+func (f *awsClientsGetter) GetAWSSTSPresignClient(cfg aws.Config) STSPresignClient {
+ stsClient := sts.NewFromConfig(cfg)
+ return sts.NewPresignClient(stsClient)
+}
+
// CheckAndSetDefaults checks and sets default values
func (c *TLSServerConfig) CheckAndSetDefaults() error {
if err := c.ForwarderConfig.CheckAndSetDefaults(); err != nil {
@@ -142,6 +162,9 @@ func (c *TLSServerConfig) CheckAndSetDefaults() error {
}
c.CloudClients = cloudClients
}
+ if c.awsClients == nil {
+ c.awsClients = &awsClientsGetter{}
+ }
if c.ConnectedProxyGetter == nil {
c.ConnectedProxyGetter = reversetunnel.NewConnectedProxyGetter()
}
diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go
index 56bea639d5260..fd83ddfd1ad60 100644
--- a/lib/kube/proxy/watcher.go
+++ b/lib/kube/proxy/watcher.go
@@ -174,6 +174,7 @@ func (m *monitoredKubeClusters) get() map[string]types.KubeCluster {
func (s *TLSServer) buildClusterDetailsConfigForCluster(cluster types.KubeCluster) clusterDetailsConfig {
return clusterDetailsConfig{
cloudClients: s.CloudClients,
+ awsCloudClients: s.awsClients,
cluster: cluster,
log: s.log,
checker: s.CheckImpersonationPermissions,
diff --git a/lib/kube/utils/eks_token_signed.go b/lib/kube/utils/eks_token_signed.go
index 4431cf93dad79..1a1840af888ef 100644
--- a/lib/kube/utils/eks_token_signed.go
+++ b/lib/kube/utils/eks_token_signed.go
@@ -19,44 +19,64 @@
package utils
import (
+ "context"
"encoding/base64"
"time"
- "github.com/aws/aws-sdk-go/service/sts"
- "github.com/aws/aws-sdk-go/service/sts/stsiface"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
)
+// STSPresignClient is the subset of the STS presign client we need to generate EKS tokens.
+type STSPresignClient interface {
+ PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error)
+}
+
// GenAWSEKSToken creates an AWS token to access EKS clusters.
// Logic from https://github.com/aws/aws-cli/blob/6c0d168f0b44136fc6175c57c090d4b115437ad1/awscli/customizations/eks/get_token.py#L211-L229
-func GenAWSEKSToken(stsClient stsiface.STSAPI, clusterID string, clock clockwork.Clock) (string, time.Time, error) {
+// TODO(@creack): Consolidate with https://github.com/gravitational/teleport/blob/d37da511c944825a47155421bf278777238eecc0/lib/integrations/awsoidc/eks_enroll_clusters.go#L341-L372
+func GenAWSEKSToken(ctx context.Context, stsClient STSPresignClient, clusterID string, clock clockwork.Clock) (string, time.Time, error) {
const (
- // The sts GetCallerIdentity request is valid for 15 minutes regardless of this parameters value after it has been
- // signed.
- requestPresignParam = 60
// The actual token expiration (presigned STS urls are valid for 15 minutes after timestamp in x-amz-date).
+ expireHeader = "X-Amz-Expires"
+ expireValue = "60"
presignedURLExpiration = 15 * time.Minute
v1Prefix = "k8s-aws-v1."
clusterIDHeader = "x-k8s-aws-id"
)
- // generate an sts:GetCallerIdentity request and add our custom cluster ID header
- request, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
- request.HTTPRequest.Header.Add(clusterIDHeader, clusterID)
-
// Sign the request. The expires parameter (sets the x-amz-expires header) is
// currently ignored by STS, and the token expires 15 minutes after the x-amz-date
// timestamp regardless. We set it to 60 seconds for backwards compatibility (the
// parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between
// 0 and 60 on the server side).
// https://github.com/aws/aws-sdk-go/issues/2167
- presignedURLString, err := request.Presign(requestPresignParam)
+ presignedReq, err := stsClient.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
+ po.ClientOptions = append(po.ClientOptions, sts.WithAPIOptions(func(stack *middleware.Stack) error {
+ return stack.Build.Add(middleware.BuildMiddlewareFunc("AddEKSId", func(
+ ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
+ ) (middleware.BuildOutput, middleware.Metadata, error) {
+ switch req := in.Request.(type) {
+ case *smithyhttp.Request:
+ query := req.URL.Query()
+ query.Add(expireHeader, expireValue)
+ req.URL.RawQuery = query.Encode()
+
+ req.Header.Add(clusterIDHeader, clusterID)
+ }
+ return next.HandleBuild(ctx, in)
+ }), middleware.Before)
+ }))
+ })
if err != nil {
return "", time.Time{}, trace.Wrap(err)
}
- // Set token expiration to 1 minute before the presigned URL expires for some cushion
+ // Set token expiration to 1 minute before the presigned URL expires for some cushion.
tokenExpiration := clock.Now().Add(presignedURLExpiration - 1*time.Minute)
- return v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration, nil
+ return v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedReq.URL)), tokenExpiration, nil
}
diff --git a/lib/service/service.go b/lib/service/service.go
index 7638ee5e85caf..7fd997e7234f0 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -54,6 +54,7 @@ import (
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
+ "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/quic-go/quic-go"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -657,6 +658,15 @@ type TeleportProcess struct {
// resolver is used to identify the reverse tunnel address when connecting via
// the proxy.
resolver reversetunnelclient.Resolver
+
+	// metricsRegistry is the prometheus metric registry for the process.
+ // Every teleport service that wants to register metrics should use this
+ // instead of the global prometheus.DefaultRegisterer to avoid registration
+ // conflicts.
+ //
+ // Both the metricsRegistry and the default global registry are gathered by
+	// Teleport's metric service.
+ metricsRegistry *prometheus.Registry
}
// processIndex is an internal process index
@@ -1179,6 +1189,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
logger: cfg.Logger,
cloudLabels: cloudLabels,
TracingProvider: tracing.NoopProvider(),
+ metricsRegistry: cfg.MetricsRegistry,
}
process.registerExpectedServices(cfg)
@@ -3405,11 +3416,46 @@ func (process *TeleportProcess) initUploaderService() error {
return nil
}
+// promHTTPLogAdapter adapts a slog.Logger into a promhttp.Logger.
+type promHTTPLogAdapter struct {
+ ctx context.Context
+ *slog.Logger
+}
+
+// Println implements the promhttp.Logger interface.
+func (l promHTTPLogAdapter) Println(v ...interface{}) {
+ //nolint:sloglint // msg cannot be constant
+ l.ErrorContext(l.ctx, fmt.Sprint(v...))
+}
+
// initMetricsService starts the metrics service currently serving metrics for
// prometheus consumption
func (process *TeleportProcess) initMetricsService() error {
mux := http.NewServeMux()
- mux.Handle("/metrics", promhttp.Handler())
+
+ // We gather metrics both from the in-process registry (preferred metrics registration method)
+ // and the global registry (used by some Teleport services and many dependencies).
+ gatherers := prometheus.Gatherers{
+ process.metricsRegistry,
+ prometheus.DefaultGatherer,
+ }
+
+ metricsHandler := promhttp.InstrumentMetricHandler(
+ process.metricsRegistry, promhttp.HandlerFor(gatherers, promhttp.HandlerOpts{
+ // Errors can happen if metrics are registered with identical names in both the local and the global registry.
+ // In this case, we log the error but continue collecting metrics. The first collected metric will win
+ // (the one from the local metrics registry takes precedence).
+			// As we move more things to the local registry, especially in other tools like tbot, we will have fewer
+			// conflicts in tests.
+ ErrorHandling: promhttp.ContinueOnError,
+ ErrorLog: promHTTPLogAdapter{
+ ctx: process.ExitContext(),
+ Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentMetrics),
+ },
+ }),
+ )
+
+ mux.Handle("/metrics", metricsHandler)
logger := process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentMetrics, process.id))
diff --git a/lib/service/service_test.go b/lib/service/service_test.go
index 52e59387ff580..4c08a87689145 100644
--- a/lib/service/service_test.go
+++ b/lib/service/service_test.go
@@ -23,9 +23,11 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "io"
"log/slog"
"net"
"net/http"
+ "net/url"
"os"
"path/filepath"
"strings"
@@ -39,6 +41,8 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
@@ -1887,7 +1891,7 @@ func TestAgentRolloutController(t *testing.T) {
dataDir := makeTempDir(t)
cfg := servicecfg.MakeDefaultConfig()
- // We use a real clock because too many sevrices are using the clock and it's not possible to accurately wait for
+ // We use a real clock because too many services are using the clock and it's not possible to accurately wait for
// each one of them to reach the point where they wait for the clock to advance. If we add a WaitUntil(X waiters)
// check, this will break the next time we add a new waiter.
cfg.Clock = clockwork.NewRealClock()
@@ -1906,7 +1910,7 @@ func TestAgentRolloutController(t *testing.T) {
process, err := NewTeleport(cfg)
require.NoError(t, err)
- // Test setup: start the Teleport auth and wait for it to beocme ready
+ // Test setup: start the Teleport auth and wait for it to become ready
require.NoError(t, process.Start())
// Test setup: wait for every service to start
@@ -1949,6 +1953,84 @@ func TestAgentRolloutController(t *testing.T) {
}, 5*time.Second, 10*time.Millisecond)
}
+func TestMetricsService(t *testing.T) {
+ t.Parallel()
+ // Test setup: create a listener for the metrics server, get its file descriptor.
+
+ // Note: this code is copied from integrations/helpers/NewListenerOn() to avoid including helpers in a production
+ // build and avoid a cyclic dependency.
+ metricsListener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ assert.NoError(t, metricsListener.Close())
+ })
+ require.IsType(t, &net.TCPListener{}, metricsListener)
+ metricsListenerFile, err := metricsListener.(*net.TCPListener).File()
+ require.NoError(t, err)
+
+ // Test setup: create a new teleport process
+ dataDir := makeTempDir(t)
+ cfg := servicecfg.MakeDefaultConfig()
+ cfg.DataDir = dataDir
+ cfg.SetAuthServerAddress(utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"})
+ cfg.Auth.Enabled = true
+ cfg.Proxy.Enabled = false
+ cfg.SSH.Enabled = false
+ cfg.DebugService.Enabled = false
+ cfg.Auth.StorageConfig.Params["path"] = dataDir
+ cfg.Auth.ListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}
+ cfg.Metrics.Enabled = true
+
+ // Configure the metrics server to use the listener we previously created.
+ cfg.Metrics.ListenAddr = &utils.NetAddr{AddrNetwork: "tcp", Addr: metricsListener.Addr().String()}
+ cfg.FileDescriptors = []*servicecfg.FileDescriptor{
+ {Type: string(ListenerMetrics), Address: metricsListener.Addr().String(), File: metricsListenerFile},
+ }
+
+ // Create and start the Teleport service.
+ process, err := NewTeleport(cfg)
+ require.NoError(t, err)
+ require.NoError(t, process.Start())
+ t.Cleanup(func() {
+ assert.NoError(t, process.Close())
+ assert.NoError(t, process.Wait())
+ })
+
+ // Test setup: create our test metrics.
+ nonce := strings.ReplaceAll(uuid.NewString(), "-", "")
+ localMetric := prometheus.NewGauge(prometheus.GaugeOpts{
+ Namespace: "test",
+ Name: "local_metric_" + nonce,
+ })
+ globalMetric := prometheus.NewGauge(prometheus.GaugeOpts{
+ Namespace: "test",
+ Name: "global_metric_" + nonce,
+ })
+ require.NoError(t, process.metricsRegistry.Register(localMetric))
+ require.NoError(t, prometheus.Register(globalMetric))
+
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
+ t.Cleanup(cancel)
+ _, err = process.WaitForEvent(ctx, MetricsReady)
+ require.NoError(t, err)
+
+ // Test execution: get metrics and check the tests metrics are here.
+ metricsURL, err := url.Parse("http://" + metricsListener.Addr().String())
+ require.NoError(t, err)
+ metricsURL.Path = "/metrics"
+ resp, err := http.Get(metricsURL.String())
+ require.NoError(t, err)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.NoError(t, resp.Body.Close())
+
+ // Test validation: check that the metrics server served both the local and global registry.
+ require.Contains(t, string(body), "local_metric_"+nonce)
+ require.Contains(t, string(body), "global_metric_"+nonce)
+}
+
// makeTempDir makes a temp dir with a shorter name than t.TempDir() in order to
// avoid https://github.com/golang/go/issues/62614.
func makeTempDir(t *testing.T) string {
diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go
index a89e79a8f6302..a89e29a2c7b54 100644
--- a/lib/service/servicecfg/config.go
+++ b/lib/service/servicecfg/config.go
@@ -34,6 +34,7 @@ import (
"github.com/ghodss/yaml"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
+ "github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/ssh"
"github.com/gravitational/teleport"
@@ -264,6 +265,12 @@ type Config struct {
// protocol.
DatabaseREPLRegistry dbrepl.REPLRegistry
+ // MetricsRegistry is the prometheus metrics registry used by the Teleport process to register its metrics.
+ // As of today, not every Teleport metric is registered against this registry. Some Teleport services
+ // and Teleport dependencies are using the global registry.
+ // Both the MetricsRegistry and the default global registry are gathered by Teleport's metric service.
+ MetricsRegistry *prometheus.Registry
+
// token is either the token needed to join the auth server, or a path pointing to a file
// that contains the token
//
@@ -520,6 +527,10 @@ func ApplyDefaults(cfg *Config) {
cfg.LoggerLevel = new(slog.LevelVar)
}
+ if cfg.MetricsRegistry == nil {
+ cfg.MetricsRegistry = prometheus.NewRegistry()
+ }
+
// Remove insecure and (borderline insecure) cryptographic primitives from
// default configuration. These can still be added back in file configuration by
// users, but not supported by default by Teleport. See #1856 for more
diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go
index 89b87ba0d2473..826ab6540b0e6 100644
--- a/lib/services/workload_identity.go
+++ b/lib/services/workload_identity.go
@@ -104,16 +104,8 @@ func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error {
if condition.Attribute == "" {
return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j)
}
- // Ensure exactly one operator is set.
- operatorsSet := 0
- if condition.Equals != "" {
- operatorsSet++
- }
- if operatorsSet == 0 || operatorsSet > 1 {
- return trace.BadParameter(
- "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d",
- i, j, operatorsSet,
- )
+ if condition.Operator == nil {
+ return trace.BadParameter("spec.rules.allow[%d].conditions[%d]: operator must be specified", i, j)
}
}
}
diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go
index 429612ed48555..27d0e1ec0261b 100644
--- a/lib/services/workload_identity_test.go
+++ b/lib/services/workload_identity_test.go
@@ -92,7 +92,11 @@ func TestValidateWorkloadIdentity(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "example",
- Equals: "foo",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "foo",
+ },
+ },
},
},
},
@@ -180,7 +184,11 @@ func TestValidateWorkloadIdentity(t *testing.T) {
Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
{
Attribute: "",
- Equals: "foo",
+ Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{
+ Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{
+ Value: "foo",
+ },
+ },
},
},
},
@@ -218,7 +226,7 @@ func TestValidateWorkloadIdentity(t *testing.T) {
},
},
},
- requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"),
+ requireErr: errContains("spec.rules.allow[0].conditions[0]: operator must be specified"),
},
}
diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go
index 46c6ca1a19f53..906acfd06c7cd 100644
--- a/lib/srv/db/access_test.go
+++ b/lib/srv/db/access_test.go
@@ -2491,7 +2491,6 @@ func (p *agentParams) setDefaults(c *testContext) {
if p.CloudClients == nil {
p.CloudClients = &clients.TestCloudClients{
STS: &mocks.STSClientV1{},
- RDS: &mocks.RDSMock{},
RedshiftServerless: &mocks.RedshiftServerlessMock{},
ElastiCache: p.ElastiCache,
MemoryDB: p.MemoryDB,
@@ -2501,7 +2500,7 @@ func (p *agentParams) setDefaults(c *testContext) {
}
}
if p.AWSConfigProvider == nil {
- p.AWSConfigProvider = &mocks.AWSConfigProvider{}
+ p.AWSConfigProvider = &mocks.AWSConfigProvider{Err: trace.AccessDenied("AWS SDK clients are disabled for tests by default")}
}
if p.DiscoveryResourceChecker == nil {
diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go
index 8222599c318a7..c336cb43230dd 100644
--- a/lib/srv/db/cloud/aws.go
+++ b/lib/srv/db/cloud/aws.go
@@ -23,21 +23,24 @@ import (
"encoding/json"
"log/slog"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam"
)
// awsConfig is the config for the client that configures IAM for AWS databases.
type awsConfig struct {
+ // awsConfigProvider provides [aws.Config] for AWS SDK service clients.
+ awsConfigProvider awsconfig.Provider
// clients is an interface for creating AWS clients.
clients cloud.Clients
// identity is AWS identity this database agent is running as.
@@ -46,6 +49,9 @@ type awsConfig struct {
database types.Database
// policyName is the name of the inline policy for the identity.
policyName string
+ // awsClients is an internal-only AWS SDK client provider that is
+ // only set in tests.
+ awsClients awsClientProvider
}
// Check validates the config.
@@ -62,6 +68,12 @@ func (c *awsConfig) Check() error {
if c.policyName == "" {
return trace.BadParameter("missing parameter policy name")
}
+ if c.awsConfigProvider == nil {
+ return trace.BadParameter("missing parameter awsConfigProvider")
+ }
+ if c.awsClients == nil {
+ return trace.BadParameter("missing parameter awsClients")
+ }
return nil
}
@@ -75,7 +87,7 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) {
teleport.ComponentKey, "aws",
"db", config.database.GetName(),
)
- dbConfigurator, err := getDBConfigurator(logger, config.clients, config.database)
+ dbConfigurator, err := getDBConfigurator(logger, config)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -102,10 +114,14 @@ type dbIAMAuthConfigurator interface {
}
// getDBConfigurator returns a database IAM Auth configurator.
-func getDBConfigurator(logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) {
- if db.IsRDS() {
+func getDBConfigurator(logger *slog.Logger, cfg awsConfig) (dbIAMAuthConfigurator, error) {
+ if cfg.database.IsRDS() {
// Only setting for RDS instances and Aurora clusters.
- return &rdsDBConfigurator{clients: clients, logger: logger}, nil
+ return &rdsDBConfigurator{
+ awsConfigProvider: cfg.awsConfigProvider,
+ logger: logger,
+ awsClients: cfg.awsClients,
+ }, nil
}
// IAM Auth for Redshift, ElastiCache, and RDS Proxy is always enabled.
return &nopDBConfigurator{}, nil
@@ -303,8 +319,9 @@ func (r *awsClient) detachIAMPolicy(ctx context.Context) error {
}
type rdsDBConfigurator struct {
- clients cloud.Clients
- logger *slog.Logger
+ awsConfigProvider awsconfig.Provider
+ logger *slog.Logger
+ awsClients awsClientProvider
}
// ensureIAMAuth enables RDS instance IAM auth if it isn't already enabled.
@@ -323,30 +340,34 @@ func (r *rdsDBConfigurator) ensureIAMAuth(ctx context.Context, db types.Database
func (r *rdsDBConfigurator) enableIAMAuth(ctx context.Context, db types.Database) error {
r.logger.DebugContext(ctx, "Enabling IAM auth for RDS")
meta := db.GetAWS()
- rdsClt, err := r.clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ if meta.RDS.ClusterID == "" && meta.RDS.InstanceID == "" {
+ return trace.BadParameter("no RDS cluster ID or instance ID for %v", db)
+ }
+ awsCfg, err := r.awsConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return trace.Wrap(err)
}
+ clt := r.awsClients.getRDSClient(awsCfg)
if meta.RDS.ClusterID != "" {
- _, err = rdsClt.ModifyDBClusterWithContext(ctx, &rds.ModifyDBClusterInput{
+ _, err = clt.ModifyDBCluster(ctx, &rds.ModifyDBClusterInput{
DBClusterIdentifier: aws.String(meta.RDS.ClusterID),
EnableIAMDatabaseAuthentication: aws.Bool(true),
ApplyImmediately: aws.Bool(true),
})
- return awslib.ConvertIAMError(err)
+ return awslib.ConvertRequestFailureErrorV2(err)
}
if meta.RDS.InstanceID != "" {
- _, err = rdsClt.ModifyDBInstanceWithContext(ctx, &rds.ModifyDBInstanceInput{
+ _, err = clt.ModifyDBInstance(ctx, &rds.ModifyDBInstanceInput{
DBInstanceIdentifier: aws.String(meta.RDS.InstanceID),
EnableIAMDatabaseAuthentication: aws.Bool(true),
ApplyImmediately: aws.Bool(true),
})
- return awslib.ConvertIAMError(err)
+ return awslib.ConvertRequestFailureErrorV2(err)
}
- return trace.BadParameter("no RDS cluster ID or instance ID for %v", db)
+ return nil
}
type nopDBConfigurator struct{}
diff --git a/lib/srv/db/cloud/iam.go b/lib/srv/db/cloud/iam.go
index aa1629157d78f..ef49e061e59b8 100644
--- a/lib/srv/db/cloud/iam.go
+++ b/lib/srv/db/cloud/iam.go
@@ -35,6 +35,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/cloud"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common/iam"
)
@@ -45,6 +46,8 @@ type IAMConfig struct {
Clock clockwork.Clock
// AccessPoint is a caching client connected to the Auth Server.
AccessPoint authclient.DatabaseAccessPoint
+ // AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
+ AWSConfigProvider awsconfig.Provider
// Clients is an interface for retrieving cloud clients.
Clients cloud.Clients
// HostID is the host identified where this agent is running.
@@ -52,6 +55,8 @@ type IAMConfig struct {
HostID string
// onProcessedTask is called after a task is processed.
onProcessedTask func(processedTask iamTask, processError error)
+ // awsClients is an SDK client provider.
+ awsClients awsClientProvider
}
// Check validates the IAM configurator config.
@@ -62,6 +67,9 @@ func (c *IAMConfig) Check() error {
if c.AccessPoint == nil {
return trace.BadParameter("missing AccessPoint")
}
+ if c.AWSConfigProvider == nil {
+ return trace.BadParameter("missing AWSConfigProvider")
+ }
if c.Clients == nil {
cloudClients, err := cloud.NewClients()
if err != nil {
@@ -72,6 +80,9 @@ func (c *IAMConfig) Check() error {
if c.HostID == "" {
return trace.BadParameter("missing HostID")
}
+ if c.awsClients == nil {
+ c.awsClients = defaultAWSClients{}
+ }
return nil
}
@@ -233,10 +244,12 @@ func (c *IAM) getAWSConfigurator(ctx context.Context, database types.Database) (
return nil, trace.Wrap(err)
}
return newAWS(ctx, awsConfig{
- clients: c.cfg.Clients,
- policyName: policyName,
- identity: identity,
- database: database,
+ awsConfigProvider: c.cfg.AWSConfigProvider,
+ clients: c.cfg.Clients,
+ database: database,
+ identity: identity,
+ policyName: policyName,
+ awsClients: c.cfg.awsClients,
})
}
diff --git a/lib/srv/db/cloud/iam_test.go b/lib/srv/db/cloud/iam_test.go
index d13d1fc74b86c..36397d6a64727 100644
--- a/lib/srv/db/cloud/iam_test.go
+++ b/lib/srv/db/cloud/iam_test.go
@@ -24,10 +24,10 @@ import (
"testing"
"time"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/iam"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
@@ -46,26 +46,28 @@ func TestAWSIAM(t *testing.T) {
t.Cleanup(cancel)
// Setup AWS database objects.
- rdsInstance := &rds.DBInstance{
+ rdsInstance := &rdstypes.DBInstance{
DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"),
DBInstanceIdentifier: aws.String("postgres-rds"),
DbiResourceId: aws.String("db-xyz"),
}
- auroraCluster := &rds.DBCluster{
+ auroraCluster := &rdstypes.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"),
DBClusterIdentifier: aws.String("postgres-aurora"),
DbClusterResourceId: aws.String("cluster-xyz"),
}
// Configure mocks.
- stsClient := &mocks.STSClientV1{
- ARN: "arn:aws:iam::123456789012:role/test-role",
+ stsClient := &mocks.STSClient{
+ STSClientV1: mocks.STSClientV1{
+ ARN: "arn:aws:iam::123456789012:role/test-role",
+ },
}
- rdsClient := &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance},
- DBClusters: []*rds.DBCluster{auroraCluster},
+ clt := &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster},
}
iamClient := &mocks.IAMMock{}
@@ -152,15 +154,20 @@ func TestAWSIAM(t *testing.T) {
}
configurator, err := NewIAM(ctx, IAMConfig{
AccessPoint: &mockAccessPoint{},
+ AWSConfigProvider: &mocks.AWSConfigProvider{
+ STSClient: stsClient,
+ },
Clients: &clients.TestCloudClients{
- RDS: rdsClient,
- STS: stsClient,
+ STS: &stsClient.STSClientV1,
IAM: iamClient,
},
HostID: "host-id",
onProcessedTask: func(iamTask, error) {
taskChan <- struct{}{}
},
+ awsClients: fakeAWSClients{
+ rdsClient: clt,
+ },
})
require.NoError(t, err)
require.NoError(t, configurator.Start(ctx))
@@ -177,6 +184,7 @@ func TestAWSIAM(t *testing.T) {
database: rdsDatabase,
wantPolicyContains: rdsDatabase.GetAWS().RDS.ResourceID,
getIAMAuthEnabled: func() bool {
+ rdsInstance := &clt.DBInstances[0]
out := aws.BoolValue(rdsInstance.IAMDatabaseAuthenticationEnabled)
// reset it
rdsInstance.IAMDatabaseAuthenticationEnabled = aws.Bool(false)
@@ -187,6 +195,7 @@ func TestAWSIAM(t *testing.T) {
database: auroraDatabase,
wantPolicyContains: auroraDatabase.GetAWS().RDS.ResourceID,
getIAMAuthEnabled: func() bool {
+ auroraCluster := &clt.DBClusters[0]
out := aws.BoolValue(auroraCluster.IAMDatabaseAuthenticationEnabled)
// reset it
auroraCluster.IAMDatabaseAuthenticationEnabled = aws.Bool(false)
@@ -291,6 +300,16 @@ func TestAWSIAMNoPermissions(t *testing.T) {
AccessPoint: &mockAccessPoint{},
Clients: &clients.TestCloudClients{}, // placeholder,
HostID: "host-id",
+ AWSConfigProvider: &mocks.AWSConfigProvider{
+ STSClient: &mocks.STSClient{
+ STSClientV1: mocks.STSClientV1{
+ ARN: "arn:aws:iam::123456789012:role/test-role",
+ },
+ },
+ },
+ awsClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{Unauth: true},
+ },
})
require.NoError(t, err)
@@ -303,7 +322,6 @@ func TestAWSIAMNoPermissions(t *testing.T) {
name: "RDS database",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{InstanceID: "postgres-rds", ResourceID: "postgres-rds-resource-id"}},
clients: &clients.TestCloudClients{
- RDS: &mocks.RDSMockUnauth{},
IAM: &mocks.IAMErrorMock{
Error: trace.AccessDenied("unauthorized"),
},
@@ -314,7 +332,6 @@ func TestAWSIAMNoPermissions(t *testing.T) {
name: "Aurora cluster",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{ClusterID: "postgres-aurora", ResourceID: "postgres-aurora-resource-id"}},
clients: &clients.TestCloudClients{
- RDS: &mocks.RDSMockUnauth{},
IAM: &mocks.IAMErrorMock{
Error: trace.AccessDenied("unauthorized"),
},
@@ -325,7 +342,6 @@ func TestAWSIAMNoPermissions(t *testing.T) {
name: "RDS database missing metadata",
meta: types.AWS{Region: "localhost", RDS: types.RDS{ClusterID: "postgres-aurora"}},
clients: &clients.TestCloudClients{
- RDS: &mocks.RDSMockUnauth{},
IAM: &mocks.IAMErrorMock{
Error: trace.AccessDenied("unauthorized"),
},
@@ -416,6 +432,7 @@ func (m *mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types.
ClusterID: "cluster-id",
})
}
+
func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) {
return &types.SemaphoreLease{
SemaphoreKind: params.SemaphoreKind,
@@ -424,6 +441,7 @@ func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.Acq
Expires: params.Expires,
}, nil
}
+
func (m *mockAccessPoint) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error {
return nil
}
diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go
index 98e2280fb1db5..0956759422b07 100644
--- a/lib/srv/db/cloud/meta.go
+++ b/lib/srv/db/cloud/meta.go
@@ -24,14 +24,14 @@ import (
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go-v2/service/redshift"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/aws/aws-sdk-go/service/memorydb/memorydbiface"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface"
"github.com/gravitational/trace"
@@ -45,13 +45,36 @@ import (
logutils "github.com/gravitational/teleport/lib/utils/log"
)
+// rdsClient defines a subset of the AWS RDS client API.
+type rdsClient interface {
+ rds.DescribeDBClustersAPIClient
+ rds.DescribeDBInstancesAPIClient
+ rds.DescribeDBProxiesAPIClient
+ rds.DescribeDBProxyEndpointsAPIClient
+ ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error)
+ ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error)
+}
+
// redshiftClient defines a subset of the AWS Redshift client API.
type redshiftClient interface {
redshift.DescribeClustersAPIClient
}
-// redshiftClientProviderFunc provides a [redshiftClient].
-type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient
+// awsClientProvider is an AWS SDK client provider.
+type awsClientProvider interface {
+ getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient
+ getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient
+}
+
+type defaultAWSClients struct{}
+
+func (defaultAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient {
+ return rds.NewFromConfig(cfg, optFns...)
+}
+
+func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
+ return redshift.NewFromConfig(cfg, optFns...)
+}
// MetadataConfig is the cloud metadata service config.
type MetadataConfig struct {
@@ -60,9 +83,8 @@ type MetadataConfig struct {
// AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
AWSConfigProvider awsconfig.Provider
- // redshiftClientProviderFn is an internal-only [redshiftClient] provider
- // func that is only set in tests.
- redshiftClientProviderFn redshiftClientProviderFunc
+ // awsClients is an SDK client provider.
+ awsClients awsClientProvider
}
// Check validates the metadata service config.
@@ -78,10 +100,8 @@ func (c *MetadataConfig) Check() error {
return trace.BadParameter("missing AWSConfigProvider")
}
- if c.redshiftClientProviderFn == nil {
- c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
- return redshift.NewFromConfig(cfg, optFns...)
- }
+ if c.awsClients == nil {
+ c.awsClients = defaultAWSClients{}
}
return nil
}
@@ -147,20 +167,21 @@ func (m *Metadata) updateAWS(ctx context.Context, database types.Database, fetch
// fetchRDSMetadata fetches metadata for the provided RDS or Aurora database.
func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database) (*types.AWS, error) {
meta := database.GetAWS()
- rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return nil, trace.Wrap(err)
}
+ clt := m.cfg.awsClients.getRDSClient(awsCfg)
if meta.RDS.ClusterID != "" {
- return fetchRDSClusterMetadata(ctx, rds, meta.RDS.ClusterID)
+ return fetchRDSClusterMetadata(ctx, clt, meta.RDS.ClusterID)
}
// Try to fetch the RDS instance fetchedMeta.
- fetchedMeta, err := fetchRDSInstanceMetadata(ctx, rds, meta.RDS.InstanceID)
+ fetchedMeta, err := fetchRDSInstanceMetadata(ctx, clt, meta.RDS.InstanceID)
if err != nil && !trace.IsNotFound(err) && !trace.IsAccessDenied(err) {
return nil, trace.Wrap(err)
}
@@ -172,11 +193,11 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database
if clusterID == "" {
clusterID = meta.RDS.InstanceID
}
- return fetchRDSClusterMetadata(ctx, rds, clusterID)
+ return fetchRDSClusterMetadata(ctx, clt, clusterID)
}
// If instance was found, it may be a part of an Aurora cluster.
if fetchedMeta.RDS.ClusterID != "" {
- return fetchRDSClusterMetadata(ctx, rds, fetchedMeta.RDS.ClusterID)
+ return fetchRDSClusterMetadata(ctx, clt, fetchedMeta.RDS.ClusterID)
}
return fetchedMeta, nil
}
@@ -184,18 +205,19 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database
// fetchRDSProxyMetadata fetches metadata for the provided RDS Proxy database.
func (m *Metadata) fetchRDSProxyMetadata(ctx context.Context, database types.Database) (*types.AWS, error) {
meta := database.GetAWS()
- rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return nil, trace.Wrap(err)
}
+ clt := m.cfg.awsClients.getRDSClient(awsCfg)
if meta.RDSProxy.CustomEndpointName != "" {
- return fetchRDSProxyCustomEndpointMetadata(ctx, rds, meta.RDSProxy.CustomEndpointName, database.GetURI())
+ return fetchRDSProxyCustomEndpointMetadata(ctx, clt, meta.RDSProxy.CustomEndpointName, database.GetURI())
}
- return fetchRDSProxyMetadata(ctx, rds, meta.RDSProxy.Name)
+ return fetchRDSProxyMetadata(ctx, clt, meta.RDSProxy.Name)
}
// fetchRedshiftMetadata fetches metadata for the provided Redshift database.
@@ -208,7 +230,7 @@ func (m *Metadata) fetchRedshiftMetadata(ctx context.Context, database types.Dat
if err != nil {
return nil, trace.Wrap(err)
}
- redshift := m.cfg.redshiftClientProviderFn(awsCfg)
+ redshift := m.cfg.awsClients.getRedshiftClient(awsCfg)
cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID)
if err != nil {
return nil, trace.Wrap(err)
@@ -275,8 +297,8 @@ func (m *Metadata) fetchMemoryDBMetadata(ctx context.Context, database types.Dat
}
// fetchRDSInstanceMetadata fetches metadata about specified RDS instance.
-func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*types.AWS, error) {
- rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID)
+func fetchRDSInstanceMetadata(ctx context.Context, clt rdsClient, instanceID string) (*types.AWS, error) {
+ rdsInstance, err := describeRDSInstance(ctx, clt, instanceID)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -284,22 +306,22 @@ func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, in
}
// describeRDSInstance returns AWS RDS instance for the specified ID.
-func describeRDSInstance(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*rds.DBInstance, error) {
- out, err := rdsClient.DescribeDBInstancesWithContext(ctx, &rds.DescribeDBInstancesInput{
+func describeRDSInstance(ctx context.Context, clt rdsClient, instanceID string) (*rdstypes.DBInstance, error) {
+ out, err := clt.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(instanceID),
})
if err != nil {
return nil, common.ConvertError(err)
}
if len(out.DBInstances) != 1 {
- return nil, trace.BadParameter("expected 1 RDS instance for %v, got %+v", instanceID, out.DBInstances)
+ return nil, trace.BadParameter("expected 1 RDS instance for %v, got %d", instanceID, len(out.DBInstances))
}
- return out.DBInstances[0], nil
+ return &out.DBInstances[0], nil
}
// fetchRDSClusterMetadata fetches metadata about specified Aurora cluster.
-func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*types.AWS, error) {
- rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID)
+func fetchRDSClusterMetadata(ctx context.Context, clt rdsClient, clusterID string) (*types.AWS, error) {
+ rdsCluster, err := describeRDSCluster(ctx, clt, clusterID)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -307,8 +329,8 @@ func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clu
}
// describeRDSCluster returns AWS Aurora cluster for the specified ID.
-func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*rds.DBCluster, error) {
- out, err := rdsClient.DescribeDBClustersWithContext(ctx, &rds.DescribeDBClustersInput{
+func describeRDSCluster(ctx context.Context, clt rdsClient, clusterID string) (*rdstypes.DBCluster, error) {
+ out, err := clt.DescribeDBClusters(ctx, &rds.DescribeDBClustersInput{
DBClusterIdentifier: aws.String(clusterID),
})
if err != nil {
@@ -317,7 +339,7 @@ func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterI
if len(out.DBClusters) != 1 {
return nil, trace.BadParameter("expected 1 RDS cluster for %v, got %+v", clusterID, out.DBClusters)
}
- return out.DBClusters[0], nil
+ return &out.DBClusters[0], nil
}
// describeRedshiftCluster returns AWS Redshift cluster for the specified ID.
@@ -364,8 +386,8 @@ func describeMemoryDBCluster(ctx context.Context, client memorydbiface.MemoryDBA
}
// fetchRDSProxyMetadata fetches metadata about specified RDS Proxy name.
-func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*types.AWS, error) {
- rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName)
+func fetchRDSProxyMetadata(ctx context.Context, clt rdsClient, proxyName string) (*types.AWS, error) {
+ rdsProxy, err := describeRDSProxy(ctx, clt, proxyName)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -373,28 +395,28 @@ func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxy
}
// describeRDSProxy returns AWS RDS Proxy for the specified RDS Proxy name.
-func describeRDSProxy(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*rds.DBProxy, error) {
- out, err := rdsClient.DescribeDBProxiesWithContext(ctx, &rds.DescribeDBProxiesInput{
+func describeRDSProxy(ctx context.Context, clt rdsClient, proxyName string) (*rdstypes.DBProxy, error) {
+ out, err := clt.DescribeDBProxies(ctx, &rds.DescribeDBProxiesInput{
DBProxyName: aws.String(proxyName),
})
if err != nil {
return nil, common.ConvertError(err)
}
if len(out.DBProxies) != 1 {
- return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %s", proxyName, out.DBProxies)
+ return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %d", proxyName, len(out.DBProxies))
}
- return out.DBProxies[0], nil
+ return &out.DBProxies[0], nil
}
// fetchRDSProxyCustomEndpointMetadata fetches metadata about specified RDS
// proxy custom endpoint.
-func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*types.AWS, error) {
- rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, uri)
+func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*types.AWS, error) {
+ rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, uri)
if err != nil {
return nil, trace.Wrap(err)
}
- rdsProxy, err := describeRDSProxy(ctx, rdsClient, aws.ToString(rdsProxyEndpoint.DBProxyName))
+ rdsProxy, err := describeRDSProxy(ctx, clt, aws.ToString(rdsProxyEndpoint.DBProxyName))
if err != nil {
return nil, trace.Wrap(err)
}
@@ -404,21 +426,27 @@ func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface
// describeRDSProxyCustomEndpointAndFindURI returns AWS RDS Proxy endpoint for
// the specified RDS Proxy custom endpoint.
-func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*rds.DBProxyEndpoint, error) {
- out, err := rdsClient.DescribeDBProxyEndpointsWithContext(ctx, &rds.DescribeDBProxyEndpointsInput{
+func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*rdstypes.DBProxyEndpoint, error) {
+ out, err := clt.DescribeDBProxyEndpoints(ctx, &rds.DescribeDBProxyEndpointsInput{
DBProxyEndpointName: aws.String(proxyEndpointName),
})
if err != nil {
return nil, common.ConvertError(err)
}
- for _, customEndpoint := range out.DBProxyEndpoints {
+ var endpoints []string
+ for _, e := range out.DBProxyEndpoints {
+ endpoint := aws.ToString(e.Endpoint)
+ if endpoint == "" {
+ continue
+ }
// Double check if it has the same URI in case multiple custom
// endpoints have the same name.
- if strings.Contains(uri, aws.ToString(customEndpoint.Endpoint)) {
- return customEndpoint, nil
+ if strings.Contains(uri, endpoint) {
+ return &e, nil
}
+ endpoints = append(endpoints, endpoint)
}
- return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, out.DBProxyEndpoints)
+ return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, endpoints)
}
func fetchRedshiftServerlessWorkgroupMetadata(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) (*types.AWS, error) {
diff --git a/lib/srv/db/cloud/meta_test.go b/lib/srv/db/cloud/meta_test.go
index 9e66a416a2ebb..9c8805f026820 100644
--- a/lib/srv/db/cloud/meta_test.go
+++ b/lib/srv/db/cloud/meta_test.go
@@ -23,11 +23,12 @@ import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go-v2/service/redshift"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/stretchr/testify/require"
@@ -40,8 +41,8 @@ import (
// TestAWSMetadata tests fetching AWS metadata for RDS and Redshift databases.
func TestAWSMetadata(t *testing.T) {
// Configure RDS API mock.
- rds := &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{
+ rdsClt := &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{
// Standalone RDS instance.
{
DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"),
@@ -56,7 +57,7 @@ func TestAWSMetadata(t *testing.T) {
DBClusterIdentifier: aws.String("postgres-aurora"),
},
},
- DBClusters: []*rds.DBCluster{
+ DBClusters: []rdstypes.DBCluster{
// Aurora cluster.
{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"),
@@ -64,16 +65,17 @@ func TestAWSMetadata(t *testing.T) {
DbClusterResourceId: aws.String("cluster-xyz"),
},
},
- DBProxies: []*rds.DBProxy{
+ DBProxies: []rdstypes.DBProxy{
{
DBProxyArn: aws.String("arn:aws:rds:us-east-1:123456789012:db-proxy:prx-resource-id"),
DBProxyName: aws.String("rds-proxy"),
},
},
- DBProxyEndpoints: []*rds.DBProxyEndpoint{
+ DBProxyEndpoints: []rdstypes.DBProxyEndpoint{
{
DBProxyEndpointName: aws.String("rds-proxy-endpoint"),
DBProxyName: aws.String("rds-proxy"),
+ Endpoint: aws.String("localhost"),
},
},
}
@@ -130,7 +132,6 @@ func TestAWSMetadata(t *testing.T) {
// Create metadata fetcher.
metadata, err := NewMetadata(MetadataConfig{
Clients: &cloud.TestCloudClients{
- RDS: rds,
ElastiCache: elasticache,
MemoryDB: memorydb,
RedshiftServerless: redshiftServerless,
@@ -139,7 +140,10 @@ func TestAWSMetadata(t *testing.T) {
AWSConfigProvider: &mocks.AWSConfigProvider{
STSClient: fakeSTS,
},
- redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt),
+ awsClients: fakeAWSClients{
+ rdsClient: rdsClt,
+ redshiftClient: redshiftClt,
+ },
})
require.NoError(t, err)
@@ -407,7 +411,7 @@ func TestAWSMetadata(t *testing.T) {
// cause an error.
func TestAWSMetadataNoPermissions(t *testing.T) {
// Create unauthorized mocks.
- rds := &mocks.RDSMockUnauth{}
+ rdsClt := &mocks.RDSClient{Unauth: true}
redshiftClt := &mocks.RedshiftClient{Unauth: true}
fakeSTS := &mocks.STSClient{}
@@ -415,13 +419,15 @@ func TestAWSMetadataNoPermissions(t *testing.T) {
// Create metadata fetcher.
metadata, err := NewMetadata(MetadataConfig{
Clients: &cloud.TestCloudClients{
- RDS: rds,
STS: &fakeSTS.STSClientV1,
},
AWSConfigProvider: &mocks.AWSConfigProvider{
STSClient: fakeSTS,
},
- redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt),
+ awsClients: fakeAWSClients{
+ rdsClient: rdsClt,
+ redshiftClient: redshiftClt,
+ },
})
require.NoError(t, err)
@@ -494,8 +500,15 @@ func TestAWSMetadataNoPermissions(t *testing.T) {
}
}
-func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc {
- return func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
- return c
- }
+type fakeAWSClients struct {
+ rdsClient rdsClient
+ redshiftClient redshiftClient
+}
+
+func (f fakeAWSClients) getRDSClient(aws.Config, ...func(*rds.Options)) rdsClient {
+ return f.rdsClient
+}
+
+func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient {
+ return f.redshiftClient
}
diff --git a/lib/srv/db/cloud/resource_checker_url.go b/lib/srv/db/cloud/resource_checker_url.go
index fdc4efdb65fe9..da8dd40fac772 100644
--- a/lib/srv/db/cloud/resource_checker_url.go
+++ b/lib/srv/db/cloud/resource_checker_url.go
@@ -28,7 +28,6 @@ import (
"sync"
"github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
@@ -42,9 +41,8 @@ import (
type urlChecker struct {
// awsConfigProvider provides [aws.Config] for AWS SDK service clients.
awsConfigProvider awsconfig.Provider
- // redshiftClientProviderFn is an internal-only [redshiftClient] provider
- // func that is only set in tests.
- redshiftClientProviderFn redshiftClientProviderFunc
+ // awsClients is an SDK client provider.
+ awsClients awsClientProvider
clients cloud.Clients
logger *slog.Logger
@@ -61,12 +59,10 @@ type urlChecker struct {
func newURLChecker(cfg DiscoveryResourceCheckerConfig) *urlChecker {
return &urlChecker{
awsConfigProvider: cfg.AWSConfigProvider,
- redshiftClientProviderFn: func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
- return redshift.NewFromConfig(cfg, optFns...)
- },
- clients: cfg.Clients,
- logger: cfg.Logger,
- warnOnError: getWarnOnError(),
+ awsClients: defaultAWSClients{},
+ clients: cfg.Clients,
+ logger: cfg.Logger,
+ warnOnError: getWarnOnError(),
}
}
diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go
index 336ee197815fb..5b87d643ea7b7 100644
--- a/lib/srv/db/cloud/resource_checker_url_aws.go
+++ b/lib/srv/db/cloud/resource_checker_url_aws.go
@@ -21,10 +21,9 @@ package cloud
import (
"context"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/opensearchservice"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface"
"github.com/gravitational/trace"
@@ -82,22 +81,23 @@ func (c *urlChecker) logAWSAccessDeniedError(ctx context.Context, database types
func (c *urlChecker) checkRDS(ctx context.Context, database types.Database) error {
meta := database.GetAWS()
- rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return trace.Wrap(err)
}
+ clt := c.awsClients.getRDSClient(awsCfg)
if meta.RDS.ClusterID != "" {
- return trace.Wrap(c.checkRDSCluster(ctx, database, rdsClient, meta.RDS.ClusterID))
+ return trace.Wrap(c.checkRDSCluster(ctx, database, clt, meta.RDS.ClusterID))
}
- return trace.Wrap(c.checkRDSInstance(ctx, database, rdsClient, meta.RDS.InstanceID))
+ return trace.Wrap(c.checkRDSInstance(ctx, database, clt, meta.RDS.InstanceID))
}
-func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, instanceID string) error {
- rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID)
+func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, clt rdsClient, instanceID string) error {
+ rdsInstance, err := describeRDSInstance(ctx, clt, instanceID)
if err != nil {
return trace.Wrap(err)
}
@@ -107,12 +107,12 @@ func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Databa
return trace.Wrap(requireDatabaseAddressPort(database, rdsInstance.Endpoint.Address, rdsInstance.Endpoint.Port))
}
-func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, clusterID string) error {
- rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID)
+func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, clt rdsClient, clusterID string) error {
+ rdsCluster, err := describeRDSCluster(ctx, clt, clusterID)
if err != nil {
return trace.Wrap(err)
}
- databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
+ databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{})
if err != nil {
c.logger.WarnContext(ctx, "Could not convert RDS cluster to database resources",
"cluster", aws.StringValue(rdsCluster.DBClusterIdentifier),
@@ -130,21 +130,22 @@ func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Databas
func (c *urlChecker) checkRDSProxy(ctx context.Context, database types.Database) error {
meta := database.GetAWS()
- rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return trace.Wrap(err)
}
+ clt := c.awsClients.getRDSClient(awsCfg)
if meta.RDSProxy.CustomEndpointName != "" {
- return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, rdsClient, meta.RDSProxy.CustomEndpointName))
+ return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, clt, meta.RDSProxy.CustomEndpointName))
}
- return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, rdsClient, meta.RDSProxy.Name))
+ return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, clt, meta.RDSProxy.Name))
}
-func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyName string) error {
- rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName)
+func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyName string) error {
+ rdsProxy, err := describeRDSProxy(ctx, clt, proxyName)
if err != nil {
return trace.Wrap(err)
}
@@ -153,8 +154,8 @@ func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database
return requireDatabaseHost(database, aws.StringValue(rdsProxy.Endpoint))
}
-func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyEndpointName string) error {
- _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, database.GetURI())
+func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyEndpointName string) error {
+ _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, database.GetURI())
return trace.Wrap(err)
}
@@ -167,7 +168,7 @@ func (c *urlChecker) checkRedshift(ctx context.Context, database types.Database)
if err != nil {
return trace.Wrap(err)
}
- redshift := c.redshiftClientProviderFn(awsCfg)
+ redshift := c.awsClients.getRedshiftClient(awsCfg)
cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID)
if err != nil {
return trace.Wrap(err)
@@ -290,15 +291,16 @@ func (c *urlChecker) checkOpenSearchEndpoint(ctx context.Context, database types
func (c *urlChecker) checkDocumentDB(ctx context.Context, database types.Database) error {
meta := database.GetAWS()
- rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return trace.Wrap(err)
}
+ clt := c.awsClients.getRDSClient(awsCfg)
- cluster, err := describeRDSCluster(ctx, rdsClient, meta.DocumentDB.ClusterID)
+ cluster, err := describeRDSCluster(ctx, clt, meta.DocumentDB.ClusterID)
if err != nil {
return trace.Wrap(err)
}
diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go
index e8ba24f624c16..40095f7efafe0 100644
--- a/lib/srv/db/cloud/resource_checker_url_aws_test.go
+++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go
@@ -22,11 +22,11 @@ import (
"context"
"testing"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/aws/aws-sdk-go/service/opensearchservice"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/stretchr/testify/require"
@@ -54,7 +54,7 @@ func TestURLChecker_AWS(t *testing.T) {
mocks.WithRDSClusterReader,
mocks.WithRDSClusterCustomEndpoint("my-custom"),
)
- rdsClusterDBs, err := common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
+ rdsClusterDBs, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{})
require.NoError(t, err)
require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom.
testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...)
@@ -121,12 +121,6 @@ func TestURLChecker_AWS(t *testing.T) {
// Mock cloud clients.
mockClients := &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance},
- DBClusters: []*rds.DBCluster{rdsCluster, docdbCluster},
- DBProxies: []*rds.DBProxy{rdsProxy},
- DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyCustomEndpoint},
- },
RedshiftServerless: &mocks.RedshiftServerlessMock{
Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup},
Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint},
@@ -143,7 +137,6 @@ func TestURLChecker_AWS(t *testing.T) {
STS: &mocks.STSClientV1{},
}
mockClientsUnauth := &cloud.TestCloudClients{
- RDS: &mocks.RDSMockUnauth{},
RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true},
ElastiCache: &mocks.ElastiCacheMock{Unauth: true},
MemoryDB: &mocks.MemoryDBMock{Unauth: true},
@@ -158,21 +151,32 @@ func TestURLChecker_AWS(t *testing.T) {
name string
clients cloud.Clients
awsConfigProvider awsconfig.Provider
- redshiftClient redshiftClient
+ awsClients awsClientProvider
}{
{
name: "API check",
clients: mockClients,
awsConfigProvider: &mocks.AWSConfigProvider{},
- redshiftClient: &mocks.RedshiftClient{
- Clusters: []redshifttypes.Cluster{redshiftCluster},
+ awsClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance},
+ DBClusters: []rdstypes.DBCluster{*rdsCluster, *docdbCluster},
+ DBProxies: []rdstypes.DBProxy{*rdsProxy},
+ DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyCustomEndpoint},
+ },
+ redshiftClient: &mocks.RedshiftClient{
+ Clusters: []redshifttypes.Cluster{redshiftCluster},
+ },
},
},
{
name: "basic endpoint check",
clients: mockClientsUnauth,
awsConfigProvider: &mocks.AWSConfigProvider{},
- redshiftClient: &mocks.RedshiftClient{Unauth: true},
+ awsClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{Unauth: true},
+ redshiftClient: &mocks.RedshiftClient{Unauth: true},
+ },
},
}
@@ -183,7 +187,7 @@ func TestURLChecker_AWS(t *testing.T) {
AWSConfigProvider: method.awsConfigProvider,
Logger: utils.NewSlogLoggerForTests(),
})
- c.redshiftClientProviderFn = newFakeRedshiftClientProvider(method.redshiftClient)
+ c.awsClients = method.awsClients
for _, database := range testCases {
t.Run(database.GetName(), func(t *testing.T) {
diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go
index e567d82d402e0..ad7183e70563b 100644
--- a/lib/srv/db/common/auth.go
+++ b/lib/srv/db/common/auth.go
@@ -35,12 +35,12 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/aws/aws-sdk-go-v2/aws"
+ rdsauth "github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
- "github.com/aws/aws-sdk-go/service/rds/rdsutils"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
@@ -131,8 +131,16 @@ type redshiftClient interface {
GetClusterCredentials(context.Context, *redshift.GetClusterCredentialsInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsOutput, error)
}
-// redshiftClientProviderFunc provides a [redshiftClient].
-type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient
+// awsClientProvider is an AWS SDK client provider.
+type awsClientProvider interface {
+ getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient
+}
+
+type defaultAWSClients struct{}
+
+func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
+ return redshift.NewFromConfig(cfg, optFns...)
+}
// AuthConfig is the database access authenticator configuration.
type AuthConfig struct {
@@ -149,10 +157,8 @@ type AuthConfig struct {
// AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
AWSConfigProvider awsconfig.Provider
- // redshiftClientProviderFn is an internal-only [redshiftClient] provider
- // func that defaults to a func that provides a real Redshift client.
- // The default is only overridden in tests.
- redshiftClientProviderFn redshiftClientProviderFunc
+ // awsClients is an SDK client provider.
+ awsClients awsClientProvider
}
// CheckAndSetDefaults validates the config and sets defaults.
@@ -176,10 +182,8 @@ func (c *AuthConfig) CheckAndSetDefaults() error {
c.Logger = slog.With(teleport.ComponentKey, "db:auth")
}
- if c.redshiftClientProviderFn == nil {
- c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
- return redshift.NewFromConfig(cfg, optFns...)
- }
+ if c.awsClients == nil {
+ c.awsClients = defaultAWSClients{}
}
return nil
}
@@ -243,9 +247,9 @@ func (a *dbAuth) WithLogger(getUpdatedLogger func(*slog.Logger) *slog.Logger) Au
// when connecting to RDS and Aurora databases.
func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
meta := database.GetAWS()
- awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region,
- cloud.WithAssumeRoleFromAWSMeta(meta),
- cloud.WithAmbientCredentials(),
+ awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region,
+ awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID),
+ awsconfig.WithAmbientCredentials(),
)
if err != nil {
return "", trace.Wrap(err)
@@ -254,11 +258,13 @@ func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, d
"database", database,
"database_user", databaseUser,
)
- token, err := rdsutils.BuildAuthToken(
+ token, err := rdsauth.BuildAuthToken(
+ ctx,
database.GetURI(),
meta.Region,
databaseUser,
- awsSession.Config.Credentials)
+ awsCfg.Credentials,
+ )
if err != nil {
policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database)
if getPolicyErr != nil {
@@ -316,7 +322,7 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent
"database_user", databaseUser,
"database_name", databaseName,
)
- client := a.cfg.redshiftClientProviderFn(awsCfg)
+ client := a.cfg.awsClients.getRedshiftClient(awsCfg)
resp, err := client.GetClusterCredentialsWithIAM(ctx, &redshift.GetClusterCredentialsWithIAMInput{
ClusterIdentifier: aws.String(meta.Redshift.ClusterID),
DbName: aws.String(databaseName),
@@ -352,7 +358,7 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types.
"database_user", databaseUser,
"database_name", databaseName,
)
- clt := a.cfg.redshiftClientProviderFn(awsCfg)
+ clt := a.cfg.awsClients.getRedshiftClient(awsCfg)
resp, err := clt.GetClusterCredentials(ctx, &redshift.GetClusterCredentialsInput{
ClusterIdentifier: aws.String(meta.Redshift.ClusterID),
DbUser: aws.String(databaseUser),
diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go
index ae136b4d53c46..d85df87c5fd54 100644
--- a/lib/srv/db/common/auth_test.go
+++ b/lib/srv/db/common/auth_test.go
@@ -609,7 +609,6 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
STS: &fakeSTS.STSClientV1,
- RDS: &mocks.RDSMock{},
RedshiftServerless: &mocks.RedshiftServerlessMock{
GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock),
},
@@ -617,10 +616,12 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
AWSConfigProvider: &mocks.AWSConfigProvider{
STSClient: fakeSTS,
},
- redshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{
- GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock),
- GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock),
- }),
+ awsClients: fakeAWSClients{
+ redshiftClient: &mocks.RedshiftClient{
+ GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock),
+ GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock),
+ },
+ },
})
require.NoError(t, err)
@@ -957,8 +958,7 @@ func generateAzureVM(t *testing.T, identities []string) armcompute.VirtualMachin
}
// authClientMock is a mock that implements AuthClient interface.
-type authClientMock struct {
-}
+type authClientMock struct{}
// GenerateDatabaseCert generates a cert using fixtures TLS CA.
func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
@@ -996,8 +996,7 @@ func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.Da
}, nil
}
-type accessPointMock struct {
-}
+type accessPointMock struct{}
// GetAuthPreference always returns types.DefaultAuthPreference().
func (m accessPointMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
@@ -1022,8 +1021,10 @@ func (m *imdsMock) GetType() types.InstanceMetadataType {
return m.instanceType
}
-func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc {
- return func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient {
- return c
- }
+type fakeAWSClients struct {
+ redshiftClient redshiftClient
+}
+
+func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient {
+ return f.redshiftClient
}
diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go
index 28fcc486bf4db..dfb1a4b164192 100644
--- a/lib/srv/db/server.go
+++ b/lib/srv/db/server.go
@@ -259,9 +259,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) {
}
if c.CloudIAM == nil {
c.CloudIAM, err = cloud.NewIAM(ctx, cloud.IAMConfig{
- AccessPoint: c.AccessPoint,
- Clients: c.CloudClients,
- HostID: c.HostID,
+ AccessPoint: c.AccessPoint,
+ AWSConfigProvider: c.AWSConfigProvider,
+ Clients: c.CloudClients,
+ HostID: c.HostID,
})
if err != nil {
return trace.Wrap(err)
diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go
index 8a7750a26a07a..6020547ea9590 100644
--- a/lib/srv/db/watcher_test.go
+++ b/lib/srv/db/watcher_test.go
@@ -320,7 +320,6 @@ func TestWatcherCloudFetchers(t *testing.T) {
reconcileCh <- d
},
CloudClients: &clients.TestCloudClients{
- RDS: &mocks.RDSMockUnauth{}, // Access denied error should not affect other fetchers.
RedshiftServerless: &mocks.RedshiftServerlessMock{
Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup},
},
@@ -358,7 +357,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t
cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"),
))
case <-time.After(time.Second):
- t.Fatal("Didn't receive reconcile event after 1s.")
+ require.FailNow(t, "Didn't receive reconcile event after 1s.")
}
}
diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go
index 821408d2208fa..534644e6be1df 100644
--- a/lib/srv/desktop/rdp/rdpclient/client.go
+++ b/lib/srv/desktop/rdp/rdpclient/client.go
@@ -93,7 +93,7 @@ func init() {
var rustLogLevel string
// initialize the Rust logger by setting $RUST_LOG based
- // on the logrus log level
+ // on the slog log level
// (unless RUST_LOG is already explicitly set, then we
// assume the user knows what they want)
rl := os.Getenv("RUST_LOG")
diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go
index 4bc207b21df01..f19d902068daf 100644
--- a/lib/srv/discovery/access_graph.go
+++ b/lib/srv/discovery/access_graph.go
@@ -501,7 +501,9 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M
fetcher, err := aws_sync.NewAWSFetcher(
ctx,
aws_sync.Config{
+ AWSConfigProvider: s.AWSConfigProvider,
CloudClients: s.CloudClients,
+ GetEKSClient: s.GetAWSSyncEKSClient,
GetEC2Client: s.GetEC2Client,
AssumeRole: assumeRole,
Regions: awsFetcher.Regions,
diff --git a/lib/srv/discovery/common/database.go b/lib/srv/discovery/common/database.go
index 8afe335f87fcb..dcff7a2c0f614 100644
--- a/lib/srv/discovery/common/database.go
+++ b/lib/srv/discovery/common/database.go
@@ -35,7 +35,6 @@ import (
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/aws/aws-sdk-go/service/opensearchservice"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/gravitational/trace"
@@ -286,7 +285,7 @@ func NewDatabaseFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers
}
// NewDatabaseFromRDSInstance creates a database resource from an RDS instance.
-func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error) {
+func NewDatabaseFromRDSInstance(instance *rdstypes.DBInstance) (types.Database, error) {
endpoint := instance.Endpoint
if endpoint == nil {
return nil, trace.BadParameter("empty endpoint")
@@ -307,7 +306,7 @@ func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error
}, aws.ToString(instance.DBInstanceIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
- URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)),
+ URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt32(endpoint.Port)),
AWS: *metadata,
})
}
@@ -492,7 +491,7 @@ func labelsFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, end
}
// NewDatabaseFromRDSCluster creates a database resource from an RDS cluster (Aurora).
-func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
+func NewDatabaseFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
@@ -508,13 +507,13 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB
}, aws.ToString(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
- URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)),
+ URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)),
AWS: *metadata,
})
}
// NewDatabaseFromRDSClusterReaderEndpoint creates a database resource from an RDS cluster reader endpoint (Aurora).
-func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
+func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
@@ -530,13 +529,13 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta
}, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader),
types.DatabaseSpecV3{
Protocol: protocol,
- URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)),
+ URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)),
AWS: *metadata,
})
}
// NewDatabasesFromRDSClusterCustomEndpoints creates database resources from RDS cluster custom endpoints (Aurora).
-func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
+func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
@@ -551,7 +550,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns
for _, endpoint := range cluster.CustomEndpoints {
// RDS custom endpoint format:
// .cluster-custom-.
- endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.ToString(endpoint))
+ endpointDetails, err := apiawsutils.ParseRDSEndpoint(endpoint)
if err != nil {
errors = append(errors, trace.Wrap(err))
continue
@@ -568,7 +567,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns
}, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName),
types.DatabaseSpecV3{
Protocol: protocol,
- URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint), aws.ToInt64(cluster.Port)),
+ URI: fmt.Sprintf("%v:%v", endpoint, aws.ToInt32(cluster.Port)),
AWS: *metadata,
// Aurora instances update their certificates upon restart, and thus custom endpoint SAN may not be available right
@@ -588,14 +587,12 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns
return databases, trace.NewAggregate(errors...)
}
-func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReaderInstance bool) {
+func checkRDSClusterMembers(cluster *rdstypes.DBCluster) (hasWriterInstance, hasReaderInstance bool) {
for _, clusterMember := range cluster.DBClusterMembers {
- if clusterMember != nil {
- if aws.ToBool(clusterMember.IsClusterWriter) {
- hasWriterInstance = true
- } else {
- hasReaderInstance = true
- }
+ if aws.ToBool(clusterMember.IsClusterWriter) {
+ hasWriterInstance = true
+ } else {
+ hasReaderInstance = true
}
}
return
@@ -603,7 +600,7 @@ func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReade
// NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora
// cluster.
-func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
+func NewDatabasesFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) {
var errors []error
var databases types.Databases
@@ -648,7 +645,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.D
// NewDatabasesFromDocumentDBCluster creates all database resources from a
// DocumentDB cluster.
-func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases, error) {
+func NewDatabasesFromDocumentDBCluster(cluster *rdstypes.DBCluster) (types.Databases, error) {
var errors []error
var databases types.Databases
@@ -682,7 +679,7 @@ func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases,
// NewDatabaseFromDocumentDBClusterEndpoint creates database resource from
// DocumentDB cluster endpoint.
-func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Database, error) {
+func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) {
endpointType := apiawsutils.DocumentDBClusterEndpoint
metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType)
if err != nil {
@@ -695,14 +692,14 @@ func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Dat
}, aws.ToString(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: types.DatabaseProtocolMongoDB,
- URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)),
+ URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)),
AWS: *metadata,
})
}
// NewDatabaseFromDocumentDBReaderEndpoint creates database resource from
// DocumentDB reader endpoint.
-func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Database, error) {
+func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) {
endpointType := apiawsutils.DocumentDBClusterReaderEndpoint
metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType)
if err != nil {
@@ -715,13 +712,13 @@ func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Data
}, aws.ToString(cluster.DBClusterIdentifier), endpointType),
types.DatabaseSpecV3{
Protocol: types.DatabaseProtocolMongoDB,
- URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)),
+ URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)),
AWS: *metadata,
})
}
// NewDatabaseFromRDSProxy creates database resource from RDS Proxy.
-func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Database, error) {
+func NewDatabaseFromRDSProxy(dbProxy *rdstypes.DBProxy, tags []rdstypes.Tag) (types.Database, error) {
metadata, err := MetadataFromRDSProxy(dbProxy)
if err != nil {
return nil, trace.Wrap(err)
@@ -744,7 +741,7 @@ func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Datab
// NewDatabaseFromRDSProxyCustomEndpoint creates database resource from RDS
// Proxy custom endpoint.
-func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, tags []*rds.Tag) (types.Database, error) {
+func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, tags []rdstypes.Tag) (types.Database, error) {
metadata, err := MetadataFromRDSProxyCustomEndpoint(dbProxy, customEndpoint)
if err != nil {
return nil, trace.Wrap(err)
@@ -1045,7 +1042,7 @@ func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.E
}
// MetadataFromRDSInstance creates AWS metadata from the provided RDS instance.
-func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) {
+func MetadataFromRDSInstance(rdsInstance *rdstypes.DBInstance) (*types.AWS, error) {
parsedARN, err := arn.Parse(aws.ToString(rdsInstance.DBInstanceArn))
if err != nil {
return nil, trace.Wrap(err)
@@ -1063,7 +1060,7 @@ func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) {
}
// MetadataFromRDSCluster creates AWS metadata from the provided RDS cluster.
-func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) {
+func MetadataFromRDSCluster(rdsCluster *rdstypes.DBCluster) (*types.AWS, error) {
parsedARN, err := arn.Parse(aws.ToString(rdsCluster.DBClusterArn))
if err != nil {
return nil, trace.Wrap(err)
@@ -1081,7 +1078,7 @@ func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) {
// MetadataFromDocumentDBCluster creates AWS metadata from the provided
// DocumentDB cluster.
-func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) (*types.AWS, error) {
+func MetadataFromDocumentDBCluster(cluster *rdstypes.DBCluster, endpointType string) (*types.AWS, error) {
parsedARN, err := arn.Parse(aws.ToString(cluster.DBClusterArn))
if err != nil {
return nil, trace.Wrap(err)
@@ -1097,13 +1094,13 @@ func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string)
}
// MetadataFromRDSProxy creates AWS metadata from the provided RDS Proxy.
-func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) {
+func MetadataFromRDSProxy(rdsProxy *rdstypes.DBProxy) (*types.AWS, error) {
parsedARN, err := arn.Parse(aws.ToString(rdsProxy.DBProxyArn))
if err != nil {
return nil, trace.Wrap(err)
}
- // rds.DBProxy has no resource ID attribute. The resource ID can be found
+ // rdstypes.DBProxy has no resource ID attribute. The resource ID can be found
// in the ARN, e.g.:
//
// arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-xxxyyyzzz
@@ -1127,7 +1124,7 @@ func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) {
// MetadataFromRDSProxyCustomEndpoint creates AWS metadata from the provided
// RDS Proxy custom endpoint.
-func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint) (*types.AWS, error) {
+func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint) (*types.AWS, error) {
// Using resource ID from the default proxy for IAM policies to gain the
// RDS connection access.
metadata, err := MetadataFromRDSProxy(rdsProxy)
@@ -1323,12 +1320,12 @@ func rdsEngineToProtocol(engine string) (string, error) {
// rdsEngineFamilyToProtocolAndPort converts RDS engine family to the database protocol and port.
func rdsEngineFamilyToProtocolAndPort(engineFamily string) (string, int, error) {
- switch engineFamily {
- case rds.EngineFamilyMysql:
+ switch rdstypes.EngineFamily(engineFamily) {
+ case rdstypes.EngineFamilyMysql:
return defaults.ProtocolMySQL, services.RDSProxyMySQLPort, nil
- case rds.EngineFamilyPostgresql:
+ case rdstypes.EngineFamilyPostgresql:
return defaults.ProtocolPostgres, services.RDSProxyPostgresPort, nil
- case rds.EngineFamilySqlserver:
+ case rdstypes.EngineFamilySqlserver:
return defaults.ProtocolSQLServer, services.RDSProxySQLServerPort, nil
}
return "", 0, trace.BadParameter("unknown RDS engine family type %q", engineFamily)
@@ -1421,7 +1418,7 @@ func labelsFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers.Serv
}
// labelsFromRDSInstance creates database labels for the provided RDS instance.
-func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[string]string {
+func labelsFromRDSInstance(rdsInstance *rdstypes.DBInstance, meta *types.AWS) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.ToString(rdsInstance.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsInstance.EngineVersion)
@@ -1433,7 +1430,7 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str
}
// labelsFromRDSCluster creates database labels for the provided RDS cluster.
-func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string {
+func labelsFromRDSCluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, endpointType string, memberInstances []rdstypes.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.ToString(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsCluster.EngineVersion)
@@ -1444,7 +1441,7 @@ func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointTy
return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList))
}
-func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpointType string) map[string]string {
+func labelsFromDocumentDBCluster(cluster *rdstypes.DBCluster, meta *types.AWS, endpointType string) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.ToString(cluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.ToString(cluster.EngineVersion)
@@ -1453,8 +1450,8 @@ func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpoi
}
// labelsFromRDSProxy creates database labels for the provided RDS Proxy.
-func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) map[string]string {
- // rds.DBProxy has no TagList.
+func labelsFromRDSProxy(rdsProxy *rdstypes.DBProxy, meta *types.AWS, tags []rdstypes.Tag) map[string]string {
+ // rdstypes.DBProxy has no TagList.
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsProxy.VpcId)
labels[types.DiscoveryLabelEngine] = aws.ToString(rdsProxy.EngineFamily)
@@ -1463,9 +1460,9 @@ func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag)
// labelsFromRDSProxyCustomEndpoint creates database labels for the provided
// RDS Proxy custom endpoint.
-func labelsFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, meta *types.AWS, tags []*rds.Tag) map[string]string {
+func labelsFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, meta *types.AWS, tags []rdstypes.Tag) map[string]string {
labels := labelsFromRDSProxy(rdsProxy, meta, tags)
- labels[types.DiscoveryLabelEndpointType] = aws.ToString(customEndpoint.TargetRole)
+ labels[types.DiscoveryLabelEndpointType] = string(customEndpoint.TargetRole)
return labels
}
diff --git a/lib/srv/discovery/common/database_test.go b/lib/srv/discovery/common/database_test.go
index ab2b45fff24bc..891c31a18bc13 100644
--- a/lib/srv/discovery/common/database_test.go
+++ b/lib/srv/discovery/common/database_test.go
@@ -28,11 +28,10 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql"
"github.com/aws/aws-sdk-go-v2/aws"
- rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshiftserverless"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
@@ -217,7 +216,7 @@ func TestDatabaseFromAzureRedisEnterprise(t *testing.T) {
// TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource.
func TestDatabaseFromRDSInstance(t *testing.T) {
- instance := &rds.DBInstance{
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"),
DBInstanceIdentifier: aws.String("instance-1"),
DBClusterIdentifier: aws.String("cluster-1"),
@@ -225,11 +224,11 @@ func TestDatabaseFromRDSInstance(t *testing.T) {
IAMDatabaseAuthenticationEnabled: aws.Bool(true),
Engine: aws.String(services.RDSEnginePostgres),
EngineVersion: aws.String("13.0"),
- Endpoint: &rds.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String("localhost"),
- Port: aws.Int64(5432),
+ Port: aws.Int32(5432),
},
- TagList: []*rds.Tag{{
+ TagList: []rdstypes.Tag{{
Key: aws.String("key"),
Value: aws.String("val"),
}},
@@ -268,7 +267,7 @@ func TestDatabaseFromRDSInstance(t *testing.T) {
// TestDatabaseFromRDSV2Instance tests converting an RDS instance (from aws sdk v2/rds) to a database resource.
func TestDatabaseFromRDSV2Instance(t *testing.T) {
- instance := &rdsTypesV2.DBInstance{
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"),
DBInstanceIdentifier: aws.String("instance-1"),
DBClusterIdentifier: aws.String("cluster-1"),
@@ -277,16 +276,16 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
IAMDatabaseAuthenticationEnabled: aws.Bool(true),
Engine: aws.String(services.RDSEnginePostgres),
EngineVersion: aws.String("13.0"),
- Endpoint: &rdsTypesV2.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String("localhost"),
Port: aws.Int32(5432),
},
- TagList: []rdsTypesV2.Tag{{
+ TagList: []rdstypes.Tag{{
Key: aws.String("key"),
Value: aws.String("val"),
}},
- DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{
- Subnets: []rdsTypesV2.Subnet{
+ DBSubnetGroup: &rdstypes.DBSubnetGroup{
+ Subnets: []rdstypes.Subnet{
{SubnetIdentifier: aws.String("")},
{SubnetIdentifier: aws.String("subnet-1234567890abcdef0")},
{SubnetIdentifier: aws.String("subnet-1234567890abcdef1")},
@@ -294,7 +293,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
},
VpcId: aws.String("vpc-asd"),
},
- VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{
+ VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{
{VpcSecurityGroupId: aws.String("")},
{VpcSecurityGroupId: aws.String("sg-1")},
{VpcSecurityGroupId: aws.String("sg-2")},
@@ -348,7 +347,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
newName := "override-1"
instance := instance
instance.TagList = append(instance.TagList,
- rdsTypesV2.Tag{
+ rdstypes.Tag{
Key: aws.String(overrideLabel),
Value: aws.String(newName),
},
@@ -365,7 +364,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
// TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource.
func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) {
for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels {
- instance := &rds.DBInstance{
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"),
DBInstanceIdentifier: aws.String("instance-1"),
DBClusterIdentifier: aws.String("cluster-1"),
@@ -373,11 +372,11 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) {
IAMDatabaseAuthenticationEnabled: aws.Bool(true),
Engine: aws.String(services.RDSEnginePostgres),
EngineVersion: aws.String("13.0"),
- Endpoint: &rds.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String("localhost"),
- Port: aws.Int64(5432),
+ Port: aws.Int32(5432),
},
- TagList: []*rds.Tag{
+ TagList: []rdstypes.Tag{
{Key: aws.String("key"), Value: aws.String("val")},
{Key: aws.String(overrideLabel), Value: aws.String("override-1")},
},
@@ -421,8 +420,8 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) {
// TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource.
func TestDatabaseFromRDSCluster(t *testing.T) {
vpcid := uuid.NewString()
- dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String(vpcid)}}}
- cluster := &rds.DBCluster{
+ dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String(vpcid)}}}
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
DBClusterIdentifier: aws.String("cluster-1"),
DbClusterResourceId: aws.String("resource-1"),
@@ -431,12 +430,12 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
EngineVersion: aws.String("8.0.0"),
Endpoint: aws.String("localhost"),
ReaderEndpoint: aws.String("reader.host"),
- Port: aws.Int64(3306),
- CustomEndpoints: []*string{
- aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"),
- aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"),
+ Port: aws.Int32(3306),
+ CustomEndpoints: []string{
+ "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com",
+ "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com",
},
- TagList: []*rds.Tag{{
+ TagList: []rdstypes.Tag{{
Key: aws.String("key"),
Value: aws.String("val"),
}},
@@ -549,9 +548,9 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
t.Run("bad custom endpoints ", func(t *testing.T) {
badCluster := *cluster
- badCluster.CustomEndpoints = []*string{
- aws.String("badendpoint1"),
- aws.String("badendpoint2"),
+ badCluster.CustomEndpoints = []string{
+ "badendpoint1",
+ "badendpoint2",
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
@@ -561,7 +560,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
// TestDatabaseFromRDSV2Cluster tests converting an RDS cluster to a database resource.
// It uses the V2 of the aws sdk.
func TestDatabaseFromRDSV2Cluster(t *testing.T) {
- cluster := &rdsTypesV2.DBCluster{
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
DBClusterIdentifier: aws.String("cluster-1"),
DbClusterResourceId: aws.String("resource-1"),
@@ -572,7 +571,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
Endpoint: aws.String("localhost"),
ReaderEndpoint: aws.String("reader.host"),
Port: aws.Int32(3306),
- VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{
+ VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{
{VpcSecurityGroupId: aws.String("")},
{VpcSecurityGroupId: aws.String("sg-1")},
{VpcSecurityGroupId: aws.String("sg-2")},
@@ -581,7 +580,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
"myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com",
"myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com",
},
- TagList: []rdsTypesV2.Tag{{
+ TagList: []rdstypes.Tag{{
Key: aws.String("key"),
Value: aws.String("val"),
}},
@@ -630,7 +629,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
newName := "override-1"
cluster.TagList = append(cluster.TagList,
- rdsTypesV2.Tag{
+ rdstypes.Tag{
Key: aws.String(overrideLabel),
Value: aws.String(newName),
},
@@ -645,10 +644,10 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
})
t.Run("DB Cluster uses network information from DB Instance when available", func(t *testing.T) {
- instance := &rdsTypesV2.DBInstance{
- DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{
+ instance := &rdstypes.DBInstance{
+ DBSubnetGroup: &rdstypes.DBSubnetGroup{
VpcId: aws.String("vpc-123"),
- Subnets: []rdsTypesV2.Subnet{
+ Subnets: []rdstypes.Subnet{
{SubnetIdentifier: aws.String("subnet-123")},
{SubnetIdentifier: aws.String("subnet-456")},
},
@@ -699,9 +698,9 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
// TestDatabaseFromRDSClusterNameOverride tests converting an RDS cluster to a database resource with overridden name.
func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
- dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String("vpc-123")}}}
+ dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String("vpc-123")}}}
for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels {
- cluster := &rds.DBCluster{
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
DBClusterIdentifier: aws.String("cluster-1"),
DbClusterResourceId: aws.String("resource-1"),
@@ -710,12 +709,12 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
EngineVersion: aws.String("8.0.0"),
Endpoint: aws.String("localhost"),
ReaderEndpoint: aws.String("reader.host"),
- Port: aws.Int64(3306),
- CustomEndpoints: []*string{
- aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"),
- aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"),
+ Port: aws.Int32(3306),
+ CustomEndpoints: []string{
+ "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com",
+ "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com",
},
- TagList: []*rds.Tag{
+ TagList: []rdstypes.Tag{
{Key: aws.String("key"), Value: aws.String("val")},
{Key: aws.String(overrideLabel), Value: aws.String("mycluster-2")},
},
@@ -831,9 +830,9 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
t.Run("bad custom endpoints ", func(t *testing.T) {
badCluster := *cluster
- badCluster.CustomEndpoints = []*string{
- aws.String("badendpoint1"),
- aws.String("badendpoint2"),
+ badCluster.CustomEndpoints = []string{
+ "badendpoint1",
+ "badendpoint2",
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
@@ -896,7 +895,7 @@ func TestNewDatabasesFromDocumentDBCluster(t *testing.T) {
tests := []struct {
name string
- inputCluster *rds.DBCluster
+ inputCluster *rdstypes.DBCluster
wantDatabases types.Databases
}{
{
@@ -929,26 +928,26 @@ func TestDatabaseFromRDSProxy(t *testing.T) {
}{
{
desc: "mysql",
- engineFamily: rds.EngineFamilyMysql,
+ engineFamily: string(rdstypes.EngineFamilyMysql),
wantProtocol: "mysql",
wantPort: 3306,
},
{
desc: "postgres",
- engineFamily: rds.EngineFamilyPostgresql,
+ engineFamily: string(rdstypes.EngineFamilyPostgresql),
wantProtocol: "postgres",
wantPort: 5432,
},
{
desc: "sqlserver",
- engineFamily: rds.EngineFamilySqlserver,
+ engineFamily: string(rdstypes.EngineFamilySqlserver),
wantProtocol: "sqlserver",
wantPort: 1433,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
- dbProxy := &rds.DBProxy{
+ dbProxy := &rdstypes.DBProxy{
DBProxyArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-abcdef"),
DBProxyName: aws.String("testproxy"),
EngineFamily: aws.String(test.engineFamily),
@@ -956,15 +955,15 @@ func TestDatabaseFromRDSProxy(t *testing.T) {
VpcId: aws.String("test-vpc-id"),
}
- dbProxyEndpoint := &rds.DBProxyEndpoint{
+ dbProxyEndpoint := &rdstypes.DBProxyEndpoint{
Endpoint: aws.String("custom.proxy.rds.test"),
DBProxyEndpointName: aws.String("custom"),
DBProxyName: aws.String("testproxy"),
DBProxyEndpointArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy-endpoint:prx-endpoint-abcdef"),
- TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly),
+ TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly,
}
- tags := []*rds.Tag{{
+ tags := []rdstypes.Tag{{
Key: aws.String("key"),
Value: aws.String("val"),
}}
@@ -1059,7 +1058,7 @@ func TestAuroraMySQLVersion(t *testing.T) {
}
for _, test := range tests {
t.Run(test.engineVersion, func(t *testing.T) {
- require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rds.DBCluster{EngineVersion: aws.String(test.engineVersion)}))
+ require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rdstypes.DBCluster{EngineVersion: aws.String(test.engineVersion)}))
})
}
}
@@ -1099,7 +1098,7 @@ func TestIsRDSClusterSupported(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- cluster := &rds.DBCluster{
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:test"),
DBClusterIdentifier: aws.String(test.name),
DbClusterResourceId: aws.String(uuid.New().String()),
@@ -1149,7 +1148,7 @@ func TestIsRDSInstanceSupported(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- cluster := &rds.DBInstance{
+ cluster := &rdstypes.DBInstance{
DBInstanceArn: aws.String("arn:aws:rds:us-east-1:123456789012:instance:test"),
DBClusterIdentifier: aws.String(test.name),
DbiResourceId: aws.String(uuid.New().String()),
diff --git a/lib/srv/discovery/common/kubernetes.go b/lib/srv/discovery/common/kubernetes.go
index 9c383a6213fda..1bddd210493da 100644
--- a/lib/srv/discovery/common/kubernetes.go
+++ b/lib/srv/discovery/common/kubernetes.go
@@ -24,7 +24,6 @@ import (
"strings"
"github.com/aws/aws-sdk-go-v2/aws/arn"
- "github.com/aws/aws-sdk-go/aws"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
@@ -40,7 +39,7 @@ func setAWSKubeName(meta types.Metadata, firstNamePart string, extraNameParts ..
}
// NewKubeClusterFromAWSEKS creates a kube_cluster resource from an EKS cluster.
-func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]*string) (types.KubeCluster, error) {
+func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]string) (types.KubeCluster, error) {
parsedARN, err := arn.Parse(clusterArn)
if err != nil {
return nil, trace.Wrap(err)
@@ -64,7 +63,7 @@ func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]*s
}
// labelsFromAWSKubeClusterTags creates kube cluster labels.
-func labelsFromAWSKubeClusterTags(tags map[string]*string, parsedARN arn.ARN) map[string]string {
+func labelsFromAWSKubeClusterTags(tags map[string]string, parsedARN arn.ARN) map[string]string {
labels := awsEKSTagsToLabels(tags)
labels[types.CloudLabel] = types.CloudAWS
labels[types.DiscoveryLabelRegion] = parsedARN.Region
@@ -74,11 +73,11 @@ func labelsFromAWSKubeClusterTags(tags map[string]*string, parsedARN arn.ARN) ma
}
// awsEKSTagsToLabels converts AWS tags to a labels map.
-func awsEKSTagsToLabels(tags map[string]*string) map[string]string {
+func awsEKSTagsToLabels(tags map[string]string) map[string]string {
labels := make(map[string]string)
for key, val := range tags {
if types.IsValidLabelKey(key) {
- labels[key] = aws.StringValue(val)
+ labels[key] = val
} else {
slog.DebugContext(context.Background(), "Skipping EKS tag that is not a valid label key", "tag", key)
}
diff --git a/lib/srv/discovery/common/kubernetes_test.go b/lib/srv/discovery/common/kubernetes_test.go
index b121c624a1e76..bd69eccaa4676 100644
--- a/lib/srv/discovery/common/kubernetes_test.go
+++ b/lib/srv/discovery/common/kubernetes_test.go
@@ -20,8 +20,8 @@ import (
"testing"
"cloud.google.com/go/container/apiv1/containerpb"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
@@ -97,16 +97,15 @@ func TestNewKubeClusterFromAWSEKS(t *testing.T) {
})
require.NoError(t, err)
- cluster := &eks.Cluster{
- Name: aws.String("cluster1"),
- Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- overrideLabel: aws.String("override-1"),
- "env": aws.String("prod"),
+ cluster := &ekstypes.Cluster{
+ Name: aws.String("cluster1"),
+ Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"),
+ Tags: map[string]string{
+ overrideLabel: "override-1",
+ "env": "prod",
},
}
- actual, err := NewKubeClusterFromAWSEKS(aws.StringValue(cluster.Name), aws.StringValue(cluster.Arn), cluster.Tags)
+ actual, err := NewKubeClusterFromAWSEKS(aws.ToString(cluster.Name), aws.ToString(cluster.Arn), cluster.Tags)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
require.NoError(t, err)
diff --git a/lib/srv/discovery/common/renaming_test.go b/lib/srv/discovery/common/renaming_test.go
index b01825725f672..7bb64f9f01bab 100644
--- a/lib/srv/discovery/common/renaming_test.go
+++ b/lib/srv/discovery/common/renaming_test.go
@@ -27,15 +27,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
azureutils "github.com/gravitational/teleport/api/utils/azure"
- libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awstesthelpers"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
"github.com/gravitational/teleport/lib/services"
@@ -365,7 +365,7 @@ func requireOverrideLabelSkipsRenaming(t *testing.T, r types.ResourceWithLabels,
func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database {
t.Helper()
- cluster := &rds.DBCluster{
+ cluster := &rdstypes.DBCluster{
DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:cluster:%v", region, accountID, name)),
DBClusterIdentifier: aws.String("cluster-1"),
DbClusterResourceId: aws.String("resource-1"),
@@ -373,29 +373,29 @@ func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel st
Engine: aws.String("aurora-mysql"),
EngineVersion: aws.String("8.0.0"),
Endpoint: aws.String("localhost"),
- Port: aws.Int64(3306),
- TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{
+ Port: aws.Int32(3306),
+ TagList: awstesthelpers.LabelsToRDSTags(map[string]string{
overrideLabel: name,
}),
}
- database, err := NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{})
+ database, err := NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{})
require.NoError(t, err)
return database
}
func makeRDSInstanceDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database {
t.Helper()
- instance := &rds.DBInstance{
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:db:%v", region, accountID, name)),
DBInstanceIdentifier: aws.String(name),
DbiResourceId: aws.String(uuid.New().String()),
Engine: aws.String(services.RDSEnginePostgres),
DBInstanceStatus: aws.String("available"),
- Endpoint: &rds.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String("localhost"),
- Port: aws.Int64(5432),
+ Port: aws.Int32(5432),
},
- TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{
+ TagList: awstesthelpers.LabelsToRDSTags(map[string]string{
overrideLabel: name,
}),
}
@@ -498,12 +498,11 @@ func labelsToAzureTags(labels map[string]string) map[string]*string {
func makeEKSKubeCluster(t *testing.T, name, region, accountID, overrideLabel string) types.KubeCluster {
t.Helper()
- eksCluster := &eks.Cluster{
- Name: aws.String(name),
- Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- overrideLabel: aws.String(name),
+ eksCluster := &ekstypes.Cluster{
+ Name: aws.String(name),
+ Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)),
+ Tags: map[string]string{
+ overrideLabel: name,
},
}
kubeCluster, err := NewKubeClusterFromAWSEKS(aws.StringValue(eksCluster.Name), aws.StringValue(eksCluster.Arn), eksCluster.Tags)
diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go
index f37ba025d2450..047553edeabde 100644
--- a/lib/srv/discovery/discovery.go
+++ b/lib/srv/discovery/discovery.go
@@ -32,8 +32,10 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
"github.com/aws/aws-sdk-go-v2/service/ssm"
ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
@@ -115,10 +117,18 @@ type gcpInstaller interface {
type Config struct {
// CloudClients is an interface for retrieving cloud clients.
CloudClients cloud.Clients
+
+ // AWSFetchersClients gets the AWS clients for the given region for the fetchers.
+ AWSFetchersClients fetchers.AWSClientGetter
+
+ // GetAWSSyncEKSClient gets an AWS EKS client for the given region for fetchers/aws-sync.
+ GetAWSSyncEKSClient aws_sync.EKSClientGetter
+
// AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
AWSConfigProvider awsconfig.Provider
// AWSDatabaseFetcherFactory provides AWS database fetchers
AWSDatabaseFetcherFactory *db.AWSFetcherFactory
+
// GetEC2Client gets an AWS EC2 client for the given region.
GetEC2Client server.EC2ClientGetter
// GetSSMClient gets an AWS SSM client for the given region.
@@ -196,6 +206,23 @@ type AccessGraphConfig struct {
Insecure bool
}
+type awsFetchersClientsGetter struct {
+ awsconfig.Provider
+}
+
+func (f *awsFetchersClientsGetter) GetAWSEKSClient(cfg aws.Config) fetchers.EKSClient {
+ return eks.NewFromConfig(cfg)
+}
+
+func (f *awsFetchersClientsGetter) GetAWSSTSClient(cfg aws.Config) fetchers.STSClient {
+ return sts.NewFromConfig(cfg)
+}
+
+func (f *awsFetchersClientsGetter) GetAWSSTSPresignClient(cfg aws.Config) fetchers.STSPresignClient {
+ stsClient := sts.NewFromConfig(cfg)
+ return sts.NewPresignClient(stsClient)
+}
+
func (c *Config) CheckAndSetDefaults() error {
if c.Matchers.IsEmpty() && c.DiscoveryGroup == "" {
return trace.BadParameter("no matchers or discovery group configured for discovery")
@@ -253,6 +280,20 @@ kubernetes matchers are present.`)
return ec2.NewFromConfig(cfg), nil
}
}
+ if c.AWSFetchersClients == nil {
+ c.AWSFetchersClients = &awsFetchersClientsGetter{
+ Provider: awsconfig.ProviderFunc(c.getAWSConfig),
+ }
+ }
+ if c.GetAWSSyncEKSClient == nil {
+ c.GetAWSSyncEKSClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws_sync.EKSClient, error) {
+ cfg, err := c.getAWSConfig(ctx, region, opts...)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return eks.NewFromConfig(cfg), nil
+ }
+ }
if c.GetSSMClient == nil {
c.GetSSMClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) {
cfg, err := c.getAWSConfig(ctx, region, opts...)
@@ -561,7 +602,7 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error {
_, otherMatchers = splitMatchers(otherMatchers, db.IsAWSMatcherType)
// Add non-integration kube fetchers.
- kubeFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.CloudClients, otherMatchers, noDiscoveryConfig)
+ kubeFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.AWSFetchersClients, otherMatchers, noDiscoveryConfig)
if err != nil {
return trace.Wrap(err)
}
@@ -714,12 +755,12 @@ func (s *Server) databaseFetchersFromMatchers(matchers Matchers, discoveryConfig
func (s *Server) kubeFetchersFromMatchers(matchers Matchers, discoveryConfigName string) ([]common.Fetcher, error) {
var result []common.Fetcher
- // AWS
+ // AWS.
awsKubeMatchers, _ := splitMatchers(matchers.AWS, func(matcherType string) bool {
return matcherType == types.AWSMatcherEKS
})
if len(awsKubeMatchers) > 0 {
- eksFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.CloudClients, awsKubeMatchers, discoveryConfigName)
+ eksFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.AWSFetchersClients, awsKubeMatchers, discoveryConfigName)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -1264,7 +1305,6 @@ func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) erro
_, vmOK := labels[types.VMIDLabel]
return subscriptionOK && vmOK
})
-
if err != nil {
return trace.Wrap(err)
}
@@ -1357,7 +1397,6 @@ func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) error {
_, nameOK := labels[types.NameLabelDiscovery]
return projectIDOK && zoneOK && nameOK
})
-
if err != nil {
return trace.Wrap(err)
}
diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go
index 865517ba4c33c..2948e10cdb916 100644
--- a/lib/srv/discovery/discovery_test.go
+++ b/lib/srv/discovery/discovery_test.go
@@ -36,18 +36,17 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3"
- awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go-v2/service/redshift"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go-v2/service/ssm"
ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
- "github.com/aws/aws-sdk-go/service/rds"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
@@ -86,6 +85,7 @@ import (
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/discovery/common"
+ "github.com/gravitational/teleport/lib/srv/discovery/fetchers"
"github.com/gravitational/teleport/lib/srv/discovery/fetchers/db"
"github.com/gravitational/teleport/lib/srv/server"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
@@ -175,10 +175,10 @@ func genEC2Instances(n int) []ec2types.Instance {
var ec2Instances []ec2types.Instance
for _, id := range genEC2InstanceIDs(n) {
ec2Instances = append(ec2Instances, ec2types.Instance{
- InstanceId: awsv2.String(id),
+ InstanceId: aws.String(id),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -324,11 +324,12 @@ func TestDiscoveryServer(t *testing.T) {
tcs := []struct {
name string
- // presentInstances is a list of servers already present in teleport
+ // presentInstances is a list of servers already present in teleport.
presentInstances []types.Server
foundEC2Instances []ec2types.Instance
ssm *mockSSMClient
emitter *mockEmitter
+ eksClusters []*ekstypes.Cluster
eksEnroller eksClustersEnroller
discoveryConfig *discoveryconfig.DiscoveryConfig
staticMatchers Matchers
@@ -339,14 +340,14 @@ func TestDiscoveryServer(t *testing.T) {
ssmRunError error
}{
{
- name: "no nodes present, 1 found ",
+ name: "no nodes present, 1 found",
presentInstances: []types.Server{},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -356,7 +357,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -401,10 +402,10 @@ func TestDiscoveryServer(t *testing.T) {
},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -414,7 +415,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -442,10 +443,10 @@ func TestDiscoveryServer(t *testing.T) {
},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -455,7 +456,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -474,7 +475,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -491,10 +492,10 @@ func TestDiscoveryServer(t *testing.T) {
presentInstances: []types.Server{},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -504,7 +505,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -538,10 +539,10 @@ func TestDiscoveryServer(t *testing.T) {
presentInstances: []types.Server{},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -551,7 +552,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -625,10 +626,10 @@ func TestDiscoveryServer(t *testing.T) {
presentInstances: []types.Server{},
foundEC2Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("instance-id-1"),
+ InstanceId: aws.String("instance-id-1"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -638,7 +639,7 @@ func TestDiscoveryServer(t *testing.T) {
ssm: &mockSSMClient{
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
- CommandId: awsv2.String("command-id-1"),
+ CommandId: aws.String("command-id-1"),
},
},
invokeOutput: &ssm.GetCommandInvocationOutput{
@@ -667,7 +668,7 @@ func TestDiscoveryServer(t *testing.T) {
staticMatchers: Matchers{},
discoveryConfig: discoveryConfigForUserTaskEC2Test,
wantInstalledInstances: []string{},
- userTasksDiscoverCheck: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ userTasksDiscoverCheck: func(t require.TestingT, i1 interface{}, i2 ...interface{}) {
existingTasks, ok := i1.([]*usertasksv1.UserTask)
require.True(t, ok, "failed to get existing tasks: %T", i1)
require.Len(t, existingTasks, 1)
@@ -693,26 +694,21 @@ func TestDiscoveryServer(t *testing.T) {
presentInstances: []types.Server{},
foundEC2Instances: []ec2types.Instance{},
ssm: &mockSSMClient{},
- cloudClients: &cloud.TestCloudClients{
- STS: &mocks.STSClientV1{},
- EKS: &mocks.EKSMock{
- Clusters: []*eks.Cluster{
- {
- Name: aws.String("cluster01"),
- Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster01"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "RunDiscover": aws.String("Please"),
- },
- },
- {
- Name: aws.String("cluster02"),
- Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster02"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "RunDiscover": aws.String("Please"),
- },
- },
+ eksClusters: []*ekstypes.Cluster{
+ {
+ Name: aws.String("cluster01"),
+ Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster01"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "RunDiscover": "Please",
+ },
+ },
+ {
+ Name: aws.String("cluster02"),
+ Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster02"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "RunDiscover": "Please",
},
},
},
@@ -737,7 +733,7 @@ func TestDiscoveryServer(t *testing.T) {
staticMatchers: Matchers{},
discoveryConfig: discoveryConfigForUserTaskEKSTest,
wantInstalledInstances: []string{},
- userTasksDiscoverCheck: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ userTasksDiscoverCheck: func(t require.TestingT, i1 interface{}, i2 ...interface{}) {
existingTasks, ok := i1.([]*usertasksv1.UserTask)
require.True(t, ok, "failed to get existing tasks: %T", i1)
require.Len(t, existingTasks, 1)
@@ -761,20 +757,21 @@ func TestDiscoveryServer(t *testing.T) {
}
for _, tc := range tcs {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
+ ctx := context.Background()
- ec2Client := &mockEC2Client{output: &ec2.DescribeInstancesOutput{
- Reservations: []ec2types.Reservation{
- {
- OwnerId: awsv2.String("owner"),
- Instances: tc.foundEC2Instances,
+ ec2Client := &mockEC2Client{
+ output: &ec2.DescribeInstancesOutput{
+ Reservations: []ec2types.Reservation{
+ {
+ OwnerId: aws.String("owner"),
+ Instances: tc.foundEC2Instances,
+ },
},
},
- }}
+ }
- ctx := context.Background()
// Create and start test auth server.
testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Dir: t.TempDir(),
@@ -782,9 +779,24 @@ func TestDiscoveryServer(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) })
+ awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{
+ Name: "my-integration",
+ }, &types.AWSOIDCIntegrationSpecV1{
+ RoleARN: "arn:aws:iam::123456789012:role/teleport",
+ })
+ require.NoError(t, err)
+ testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{
+ proxies: nil,
+ integrations: map[string]types.Integration{
+ awsOIDCIntegration.GetName(): awsOIDCIntegration,
+ },
+ }
+
tlsServer, err := testAuthServer.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, tlsServer.Close()) })
+ _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration)
+ require.NoError(t, err)
// Auth client for discovery service.
identity := auth.TestServerID(types.RoleDiscovery, "hostID")
@@ -816,6 +828,9 @@ func TestDiscoveryServer(t *testing.T) {
eksEnroller = tc.eksEnroller
}
+ fakeConfigProvider := mocks.AWSConfigProvider{
+ OIDCIntegrationClient: tlsServer.Auth(),
+ }
server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{
GetEC2Client: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) {
return ec2Client, nil
@@ -823,6 +838,11 @@ func TestDiscoveryServer(t *testing.T) {
GetSSMClient: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) {
return tc.ssm, nil
},
+ AWSConfigProvider: &fakeConfigProvider,
+ AWSFetchersClients: &mockFetchersClients{
+ AWSConfigProvider: fakeConfigProvider,
+ eksClusters: tc.eksClusters,
+ },
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
AccessPoint: getDiscoveryAccessPointWithEKSEnroller(tlsServer.Auth(), authClient, eksEnroller),
@@ -916,20 +936,20 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
output: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
- OwnerId: awsv2.String("123456789012"),
+ OwnerId: aws.String("123456789012"),
Instances: []ec2types.Instance{
{
- InstanceId: awsv2.String("i-123456789012"),
+ InstanceId: aws.String("i-123456789012"),
Tags: []ec2types.Tag{
{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
},
},
- PrivateIpAddress: awsv2.String("172.0.1.2"),
- VpcId: awsv2.String("vpcId"),
- SubnetId: awsv2.String("subnetId"),
- PrivateDnsName: awsv2.String("privateDnsName"),
+ PrivateIpAddress: aws.String("172.0.1.2"),
+ VpcId: aws.String("vpcId"),
+ SubnetId: aws.String("subnetId"),
+ PrivateDnsName: aws.String("privateDnsName"),
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
},
@@ -1212,11 +1232,12 @@ func TestDiscoveryKubeServices(t *testing.T) {
}
func TestDiscoveryInCloudKube(t *testing.T) {
+ t.Parallel()
+
const (
mainDiscoveryGroup = "main"
otherDiscoveryGroup = "other"
)
- t.Parallel()
tcs := []struct {
name string
existingKubeClusters []types.KubeCluster
@@ -1440,15 +1461,11 @@ func TestDiscoveryInCloudKube(t *testing.T) {
}
for _, tc := range tcs {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
- sts := &mocks.STSClientV1{}
testCloudClients := &cloud.TestCloudClients{
- STS: sts,
AzureAKSClient: newPopulatedAKSMock(),
- EKS: newPopulatedEKSMock(),
GCPGKE: newPopulatedGCPMock(),
GCPProjects: newPopulatedGCPProjectsMock(),
}
@@ -1475,7 +1492,7 @@ func TestDiscoveryInCloudKube(t *testing.T) {
err := tlsServer.Auth().CreateKubernetesCluster(ctx, kubeCluster)
require.NoError(t, err)
}
- // we analyze the logs emitted by discovery service to detect clusters that were not updated
+ // We analyze the logs emitted by discovery service to detect clusters that were not updated
// because their state didn't change.
r, w := io.Pipe()
t.Cleanup(func() {
@@ -1506,15 +1523,26 @@ func TestDiscoveryInCloudKube(t *testing.T) {
}
}
}()
+
reporter := &mockUsageReporter{}
tlsServer.Auth().SetUsageReporter(reporter)
+
+ mockedClients := &mockFetchersClients{
+ AWSConfigProvider: mocks.AWSConfigProvider{
+ STSClient: &mocks.STSClient{},
+ OIDCIntegrationClient: newFakeAccessPoint(),
+ },
+ eksClusters: newPopulatedEKSMock().clusters,
+ }
+
discServer, err := New(
authz.ContextWithUser(ctx, identity.I),
&Config{
- CloudClients: testCloudClients,
- ClusterFeatures: func() proto.Features { return proto.Features{} },
- KubernetesClient: fake.NewSimpleClientset(),
- AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
+ CloudClients: testCloudClients,
+ AWSFetchersClients: mockedClients,
+ ClusterFeatures: func() proto.Features { return proto.Features{} },
+ KubernetesClient: fake.NewSimpleClientset(),
+ AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
Matchers: Matchers{
AWS: tc.awsMatchers,
Azure: tc.azureMatchers,
@@ -1524,12 +1552,9 @@ func TestDiscoveryInCloudKube(t *testing.T) {
Log: logger,
DiscoveryGroup: mainDiscoveryGroup,
})
-
require.NoError(t, err)
- t.Cleanup(func() {
- discServer.Stop()
- })
+ t.Cleanup(discServer.Stop)
go discServer.Start()
clustersNotUpdatedMap := sliceToSet(tc.clustersNotUpdated)
@@ -1562,8 +1587,8 @@ func TestDiscoveryInCloudKube(t *testing.T) {
return len(clustersNotUpdated) == 0 && clustersFoundInAuth
}, 5*time.Second, 200*time.Millisecond)
- require.ElementsMatch(t, tc.expectedAssumedRoles, sts.GetAssumedRoleARNs(), "roles incorrectly assumed")
- require.ElementsMatch(t, tc.expectedExternalIDs, sts.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed")
+ require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.STSClient.GetAssumedRoleARNs(), "roles incorrectly assumed")
+ require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.STSClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed")
if tc.wantEvents > 0 {
require.Eventually(t, func() bool {
@@ -1582,14 +1607,15 @@ func TestDiscoveryServer_New(t *testing.T) {
t.Parallel()
testCases := []struct {
desc string
- cloudClients cloud.Clients
+ cloudClients fetchers.AWSClientGetter
matchers Matchers
errAssertion require.ErrorAssertionFunc
discServerAssertion require.ValueAssertionFunc
}{
{
- desc: "no matchers error",
- cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}},
+ desc: "no matchers error",
+
+ cloudClients: &mockFetchersClients{},
matchers: Matchers{},
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, &trace.BadParameterError{Message: "no matchers or discovery group configured for discovery"})
@@ -1597,8 +1623,10 @@ func TestDiscoveryServer_New(t *testing.T) {
discServerAssertion: require.Nil,
},
{
- desc: "success with EKS matcher",
- cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}, EKS: &mocks.EKSMock{}},
+ desc: "success with EKS matcher",
+
+ cloudClients: &mockFetchersClients{},
+
matchers: Matchers{
AWS: []types.AWSMatcher{
{
@@ -1621,11 +1649,8 @@ func TestDiscoveryServer_New(t *testing.T) {
},
},
{
- desc: "EKS fetcher is skipped on initialization error (missing region)",
- cloudClients: &cloud.TestCloudClients{
- STS: &mocks.STSClientV1{},
- EKS: &mocks.EKSMock{},
- },
+ desc: "EKS fetcher is skipped on initialization error (missing region)",
+ cloudClients: &mockFetchersClients{},
matchers: Matchers{
AWS: []types.AWSMatcher{
{
@@ -1666,12 +1691,12 @@ func TestDiscoveryServer_New(t *testing.T) {
discServer, err := New(
ctx,
&Config{
- CloudClients: tt.cloudClients,
- ClusterFeatures: func() proto.Features { return proto.Features{} },
- AccessPoint: newFakeAccessPoint(),
- Matchers: tt.matchers,
- Emitter: &mockEmitter{},
- protocolChecker: &noopProtocolChecker{},
+ AWSFetchersClients: tt.cloudClients,
+ ClusterFeatures: func() proto.Features { return proto.Features{} },
+ AccessPoint: newFakeAccessPoint(),
+ Matchers: tt.matchers,
+ Emitter: &mockEmitter{},
+ protocolChecker: &noopProtocolChecker{},
})
tt.errAssertion(t, err)
@@ -1759,28 +1784,33 @@ var aksMockClusters = map[string][]*azure.AKSCluster{
}
type mockEKSAPI struct {
- eksiface.EKSAPI
- clusters []*eks.Cluster
+ fetchers.EKSClient
+ clusters []*ekstypes.Cluster
}
-func (m *mockEKSAPI) ListClustersPagesWithContext(ctx aws.Context, req *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error {
- var names []*string
+func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) {
+ var names []string
for _, cluster := range m.clusters {
- names = append(names, cluster.Name)
+ names = append(names, aws.ToString(cluster.Name))
}
- f(&eks.ListClustersOutput{
- Clusters: names[:len(names)/2],
- }, false)
- f(&eks.ListClustersOutput{
+ // First call, no NextToken. Return first half and a NextToken value.
+ if req.NextToken == nil {
+ return &eks.ListClustersOutput{
+ Clusters: names[:len(names)/2],
+ NextToken: aws.String("next"),
+ }, nil
+ }
+
+ // Second call, we have a NextToken, return the second half.
+ return &eks.ListClustersOutput{
Clusters: names[len(names)/2:],
- }, true)
- return nil
+ }, nil
}
-func (m *mockEKSAPI) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
+func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) {
for _, cluster := range m.clusters {
- if aws.StringValue(cluster.Name) == aws.StringValue(req.Name) {
+ if aws.ToString(cluster.Name) == aws.ToString(req.Name) {
return &eks.DescribeClusterOutput{
Cluster: cluster,
}, nil
@@ -1795,48 +1825,70 @@ func newPopulatedEKSMock() *mockEKSAPI {
}
}
-var eksMockClusters = []*eks.Cluster{
+type mockFetchersClients struct {
+ mocks.AWSConfigProvider
+ eksClusters []*ekstypes.Cluster
+}
+
+func (m *mockFetchersClients) GetAWSEKSClient(aws.Config) fetchers.EKSClient {
+ return &mockEKSAPI{
+ clusters: m.eksClusters,
+ }
+}
+
+func (m *mockFetchersClients) GetAWSSTSClient(aws.Config) fetchers.STSClient {
+ if m.AWSConfigProvider.STSClient != nil {
+ return m.AWSConfigProvider.STSClient
+ }
+ return &mocks.STSClient{}
+}
+
+func (m *mockFetchersClients) GetAWSSTSPresignClient(aws.Config) fetchers.STSPresignClient {
+ return nil
+}
+
+var eksMockClusters = []*ekstypes.Cluster{
{
Name: aws.String("eks-cluster1"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("prod"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "prod",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("eks-cluster2"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("prod"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "prod",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("eks-cluster3"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster3"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("stg"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "stg",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("eks-cluster4"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("stg"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "stg",
+ "location": "eu-west-1",
},
},
}
-func mustConvertEKSToKubeCluster(t *testing.T, eksCluster *eks.Cluster, discoveryParams rewriteDiscoveryLabelsParams) types.KubeCluster {
- cluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(eksCluster.Name), aws.StringValue(eksCluster.Arn), eksCluster.Tags)
+func mustConvertEKSToKubeCluster(t *testing.T, eksCluster *ekstypes.Cluster, discoveryParams rewriteDiscoveryLabelsParams) types.KubeCluster {
+ cluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksCluster.Tags)
require.NoError(t, err)
discoveryParams.matcherType = types.AWSMatcherEKS
rewriteCloudResource(t, cluster, discoveryParams)
@@ -2012,13 +2064,7 @@ func TestDiscoveryDatabase(t *testing.T) {
}
testCloudClients := &cloud.TestCloudClients{
- STS: &mocks.STSClientV1{},
- RDS: &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{awsRDSInstance},
- DBEngineVersions: []*rds.DBEngineVersion{
- {Engine: aws.String(services.RDSEnginePostgres)},
- },
- },
+ STS: &mocks.STSClientV1{},
MemoryDB: &mocks.MemoryDBMock{},
AzureRedis: azure.NewRedisClientByAPI(&azure.ARMRedisMock{
Servers: []*armredis.ResourceInfo{azRedisResource},
@@ -2027,9 +2073,6 @@ func TestDiscoveryDatabase(t *testing.T) {
&azure.ARMRedisEnterpriseClusterMock{},
&azure.ARMRedisEnterpriseDatabaseMock{},
),
- EKS: &mocks.EKSMock{
- Clusters: []*eks.Cluster{eksAWSResource},
- },
}
tcs := []struct {
@@ -2303,7 +2346,6 @@ func TestDiscoveryDatabase(t *testing.T) {
}
for _, tc := range tcs {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -2360,9 +2402,17 @@ func TestDiscoveryDatabase(t *testing.T) {
dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{
AWSConfigProvider: fakeConfigProvider,
CloudClients: testCloudClients,
- RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{
- Clusters: []redshifttypes.Cluster{*awsRedshiftResource},
- }),
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*awsRDSInstance},
+ DBEngineVersions: []rdstypes.DBEngineVersion{
+ {Engine: aws.String(services.RDSEnginePostgres)},
+ },
+ },
+ redshiftClient: &mocks.RedshiftClient{
+ Clusters: []redshifttypes.Cluster{*awsRedshiftResource},
+ },
+ },
})
require.NoError(t, err)
@@ -2370,12 +2420,16 @@ func TestDiscoveryDatabase(t *testing.T) {
authz.ContextWithUser(ctx, identity.I),
&Config{
IntegrationOnlyCredentials: integrationOnlyCredential,
- CloudClients: testCloudClients,
- AWSDatabaseFetcherFactory: dbFetcherFactory,
- AWSConfigProvider: fakeConfigProvider,
- ClusterFeatures: func() proto.Features { return proto.Features{} },
- KubernetesClient: fake.NewSimpleClientset(),
- AccessPoint: accessPoint,
+ AWSFetchersClients: &mockFetchersClients{
+ AWSConfigProvider: *fakeConfigProvider,
+ eksClusters: []*ekstypes.Cluster{eksAWSResource},
+ },
+ CloudClients: testCloudClients,
+ ClusterFeatures: func() proto.Features { return proto.Features{} },
+ KubernetesClient: fake.NewSimpleClientset(),
+ AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
+ AWSDatabaseFetcherFactory: dbFetcherFactory,
+ AWSConfigProvider: fakeConfigProvider,
Matchers: Matchers{
AWS: tc.awsMatchers,
Azure: tc.azureMatchers,
@@ -2420,7 +2474,7 @@ func TestDiscoveryDatabase(t *testing.T) {
cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"),
))
case <-time.After(time.Second):
- t.Fatal("Didn't receive reconcile event after 1s")
+ require.FailNow(t, "Didn't receive reconcile event after 1s")
}
if tc.wantEvents > 0 {
@@ -2452,15 +2506,25 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) {
awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1", rewriteDiscoveryLabelsParams{discoveryConfigName: dc2Name, discoveryGroup: mainDiscoveryGroup})
+ fakeConfigProvider := &mocks.AWSConfigProvider{
+ STSClient: &mocks.STSClient{},
+ }
testCloudClients := &cloud.TestCloudClients{
- STS: &mocks.STSClientV1{},
- RDS: &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{awsRDSInstance},
- DBEngineVersions: []*rds.DBEngineVersion{
- {Engine: aws.String(services.RDSEnginePostgres)},
+ STS: &fakeConfigProvider.STSClient.STSClientV1,
+ }
+ dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{
+ AWSConfigProvider: fakeConfigProvider,
+ CloudClients: testCloudClients,
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*awsRDSInstance},
+ DBEngineVersions: []rdstypes.DBEngineVersion{
+ {Engine: aws.String(services.RDSEnginePostgres)},
+ },
},
},
- }
+ })
+ require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
@@ -2488,14 +2552,16 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) {
srv, err := New(
authz.ContextWithUser(ctx, identity.I),
&Config{
- CloudClients: testCloudClients,
- ClusterFeatures: func() proto.Features { return proto.Features{} },
- KubernetesClient: fake.NewSimpleClientset(),
- AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
- Matchers: Matchers{},
- Emitter: authClient,
- DiscoveryGroup: mainDiscoveryGroup,
- clock: clock,
+ AWSConfigProvider: fakeConfigProvider,
+ AWSDatabaseFetcherFactory: dbFetcherFactory,
+ CloudClients: testCloudClients,
+ ClusterFeatures: func() proto.Features { return proto.Features{} },
+ KubernetesClient: fake.NewSimpleClientset(),
+ AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
+ Matchers: Matchers{},
+ Emitter: authClient,
+ DiscoveryGroup: mainDiscoveryGroup,
+ clock: clock,
})
require.NoError(t, err)
@@ -2601,33 +2667,33 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) {
})
}
-func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*eks.Cluster, types.KubeCluster) {
+func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*ekstypes.Cluster, types.KubeCluster) {
t.Helper()
- eksAWSCluster := &eks.Cluster{
+ eksAWSCluster := &ekstypes.Cluster{
Name: aws.String(name),
Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:123456789012:cluster/%s", region, name)),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("prod"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "prod",
},
}
- actual, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(eksAWSCluster.Name), aws.StringValue(eksAWSCluster.Arn), eksAWSCluster.Tags)
+ actual, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksAWSCluster.Name), aws.ToString(eksAWSCluster.Arn), eksAWSCluster.Tags)
require.NoError(t, err)
discoveryParams.matcherType = types.AWSMatcherEKS
rewriteCloudResource(t, actual, discoveryParams)
return eksAWSCluster, actual
}
-func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rds.DBInstance, types.Database) {
- instance := &rds.DBInstance{
+func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rdstypes.DBInstance, types.Database) {
+ instance := &rdstypes.DBInstance{
DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)),
DBInstanceIdentifier: aws.String(name),
DbiResourceId: aws.String(uuid.New().String()),
Engine: aws.String(services.RDSEnginePostgres),
DBInstanceStatus: aws.String("available"),
- Endpoint: &rds.Endpoint{
+ Endpoint: &rdstypes.Endpoint{
Address: aws.String("localhost"),
- Port: aws.Int64(5432),
+ Port: aws.Int32(5432),
},
}
database, err := common.NewDatabaseFromRDSInstance(instance)
@@ -2986,6 +3052,7 @@ func (m *mockGCPClient) getVMSForProject(projectID string) []*gcpimds.Instance {
}
return vms
}
+
func (m *mockGCPClient) ListInstances(_ context.Context, projectID, _ string) ([]*gcpimds.Instance, error) {
return m.getVMSForProject(projectID), nil
}
@@ -3696,8 +3763,15 @@ func newPopulatedGCPProjectsMock() *mockProjectsAPI {
}
}
-func newFakeRedshiftClientProvider(c redshift.DescribeClustersAPIClient) db.RedshiftClientProviderFunc {
- return func(cfg awsv2.Config, optFns ...func(*redshift.Options)) db.RedshiftClient {
- return c
- }
+type fakeAWSClients struct {
+ rdsClient db.RDSClient
+ redshiftClient db.RedshiftClient
+}
+
+func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) db.RDSClient {
+ return f.rdsClient
+}
+
+func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient {
+ return f.redshiftClient
}
diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
index 2a7e928370091..a65742fc38856 100644
--- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
+++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
@@ -24,9 +24,9 @@ import (
"sync"
"time"
- awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
- "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gravitational/trace"
@@ -45,8 +45,12 @@ const pageSize int64 = 500
// Config is the configuration for the AWS fetcher.
type Config struct {
+ // AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
+ AWSConfigProvider awsconfig.Provider
// CloudClients is the cloud clients to use when fetching AWS resources.
CloudClients cloud.Clients
+ // GetEKSClient gets an AWS EKS client for the given region.
+ GetEKSClient EKSClientGetter
// GetEC2Client gets an AWS EC2 client for the given region.
GetEC2Client server.EC2ClientGetter
// AccountID is the AWS account ID to use when fetching resources.
@@ -59,6 +63,32 @@ type Config struct {
Integration string
// DiscoveryConfigName if set, will be used to report the Discovery Config Status to the Auth Server.
DiscoveryConfigName string
+
+ // awsClients provides AWS SDK clients.
+ awsClients awsClientProvider
+}
+
+func (c *Config) CheckAndSetDefaults() error {
+ if c.AWSConfigProvider == nil {
+ return trace.BadParameter("missing AWSConfigProvider")
+ }
+
+ if c.awsClients == nil {
+ c.awsClients = defaultAWSClients{}
+ }
+ return nil
+}
+
+// awsClientProvider provides AWS service API clients.
+type awsClientProvider interface {
+ // getRDSClient provides an [rdsClient].
+ getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient
+}
+
+type defaultAWSClients struct{}
+
+func (defaultAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient {
+ return rds.NewFromConfig(cfg, optFns...)
}
// AssumeRole is the configuration for assuming an AWS role.
@@ -182,6 +212,9 @@ func (r *Resources) UsageReport(numberAccounts int) *usageeventsv1.AccessGraphAW
// NewAWSFetcher creates a new AWS fetcher.
func NewAWSFetcher(ctx context.Context, cfg Config) (AWSSync, error) {
+ if err := cfg.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
a := &awsFetcher{
Config: cfg,
lastResult: &Resources{},
@@ -335,7 +368,7 @@ func (a *awsFetcher) getAWSV2Options() []awsconfig.OptionsFn {
opts = append(opts, awsconfig.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID))
}
const maxRetries = 10
- opts = append(opts, awsconfig.WithRetryer(func() awsv2.Retryer {
+ opts = append(opts, awsconfig.WithRetryer(func() aws.Retryer {
return retry.NewStandard(func(so *retry.StandardOptions) {
so.MaxAttempts = maxRetries
so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second)
@@ -361,7 +394,7 @@ func (a *awsFetcher) getAccountId(ctx context.Context) (string, error) {
return "", trace.Wrap(err)
}
- return aws.StringValue(req.Account), nil
+ return aws.ToString(req.Account), nil
}
func (a *awsFetcher) DiscoveryConfigName() string {
diff --git a/lib/srv/discovery/fetchers/aws-sync/eks.go b/lib/srv/discovery/fetchers/aws-sync/eks.go
index e4a7cc768ecd2..fc1791b4cb13a 100644
--- a/lib/srv/discovery/fetchers/aws-sync/eks.go
+++ b/lib/srv/discovery/fetchers/aws-sync/eks.go
@@ -22,16 +22,32 @@ import (
"context"
"sync"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/timestamppb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
)
+// EKSClientGetter returns an EKS client for aws-sync.
+type EKSClientGetter func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (EKSClient, error)
+
+// EKSClient is the subset of the EKS interface we use in aws-sync.
+type EKSClient interface {
+ eks.ListClustersAPIClient
+ eks.DescribeClusterAPIClient
+
+ eks.ListAccessEntriesAPIClient
+ DescribeAccessEntry(ctx context.Context, params *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error)
+
+ eks.ListAssociatedAccessPoliciesAPIClient
+}
+
// pollAWSEKSClusters is a function that returns a function that fetches
// eks clusters and their access scope levels.
func (a *awsFetcher) pollAWSEKSClusters(ctx context.Context, result *Resources, collectErr func(error)) func() error {
@@ -70,7 +86,8 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust
collectClusters := func(cluster *accessgraphv1alpha.AWSEKSClusterV1,
clusterAssociatedPolicies []*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1,
clusterAccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1,
- err error) {
+ err error,
+ ) {
hostsMu.Lock()
defer hostsMu.Unlock()
if err != nil {
@@ -86,41 +103,34 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust
for _, region := range a.Regions {
region := region
eG.Go(func() error {
- eksClient, err := a.CloudClients.GetAWSEKSClient(ctx, region, a.getAWSOptions()...)
+ eksClient, err := a.GetEKSClient(ctx, region, a.getAWSV2Options()...)
if err != nil {
collectClusters(nil, nil, nil, trace.Wrap(err))
return nil
}
var eksClusterNames []string
- // ListClustersPagesWithContext returns a list of EKS cluster names existing in the region.
- err = eksClient.ListClustersPagesWithContext(
- ctx,
- &eks.ListClustersInput{},
- func(output *eks.ListClustersOutput, lastPage bool) bool {
- for _, cluster := range output.Clusters {
- eksClusterNames = append(eksClusterNames, aws.StringValue(cluster))
- }
- return !lastPage
-
- },
- )
- if err != nil {
- oldEKSClusters := sliceFilter(existing.EKSClusters, func(cluster *accessgraphv1alpha.AWSEKSClusterV1) bool {
- return cluster.Region == region && cluster.AccountId == a.AccountID
- })
- oldAccessEntries := sliceFilter(existing.AccessEntries, func(ae *accessgraphv1alpha.AWSEKSClusterAccessEntryV1) bool {
- return ae.Cluster.Region == region && ae.AccountId == a.AccountID
- })
- oldAssociatedPolicies := sliceFilter(existing.AssociatedAccessPolicies, func(ap *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1) bool {
- return ap.Cluster.Region == region && ap.AccountId == a.AccountID
- })
- hostsMu.Lock()
- output.clusters = append(output.clusters, oldEKSClusters...)
- output.associatedPolicies = append(output.associatedPolicies, oldAssociatedPolicies...)
- output.accessEntry = append(output.accessEntry, oldAccessEntries...)
- hostsMu.Unlock()
+ for p := eks.NewListClustersPaginator(eksClient, nil); p.HasMorePages(); {
+ out, err := p.NextPage(ctx)
+ if err != nil {
+ oldEKSClusters := sliceFilter(existing.EKSClusters, func(cluster *accessgraphv1alpha.AWSEKSClusterV1) bool {
+ return cluster.Region == region && cluster.AccountId == a.AccountID
+ })
+ oldAccessEntries := sliceFilter(existing.AccessEntries, func(ae *accessgraphv1alpha.AWSEKSClusterAccessEntryV1) bool {
+ return ae.Cluster.Region == region && ae.AccountId == a.AccountID
+ })
+ oldAssociatedPolicies := sliceFilter(existing.AssociatedAccessPolicies, func(ap *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1) bool {
+ return ap.Cluster.Region == region && ap.AccountId == a.AccountID
+ })
+ hostsMu.Lock()
+ output.clusters = append(output.clusters, oldEKSClusters...)
+ output.associatedPolicies = append(output.associatedPolicies, oldAssociatedPolicies...)
+ output.accessEntry = append(output.accessEntry, oldAccessEntries...)
+ hostsMu.Unlock()
+ break
+ }
+ eksClusterNames = append(eksClusterNames, out.Clusters...)
}
for _, cluster := range eksClusterNames {
@@ -134,7 +144,7 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust
return ap.Cluster.Name == cluster && ap.AccountId == a.AccountID && ap.Cluster.Region == region
})
// DescribeCluster retrieves the cluster details.
- cluster, err := eksClient.DescribeClusterWithContext(ctx, &eks.DescribeClusterInput{
+ cluster, err := eksClient.DescribeCluster(ctx, &eks.DescribeClusterInput{
Name: aws.String(cluster),
},
)
@@ -147,7 +157,7 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust
// if eks cluster only allows CONFIGMAP auth, skip polling of access entries and
// associated policies.
if cluster.Cluster != nil && cluster.Cluster.AccessConfig != nil &&
- aws.StringValue(cluster.Cluster.AccessConfig.AuthenticationMode) == eks.AuthenticationModeConfigMap {
+ cluster.Cluster.AccessConfig.AuthenticationMode == ekstypes.AuthenticationModeConfigMap {
collectClusters(protoCluster, nil, nil, nil)
continue
}
@@ -181,20 +191,20 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust
// awsEKSClusterToProtoCluster converts an eks.Cluster to accessgraphv1alpha.AWSEKSClusterV1
// representation.
-func awsEKSClusterToProtoCluster(cluster *eks.Cluster, region, accountID string) *accessgraphv1alpha.AWSEKSClusterV1 {
+func awsEKSClusterToProtoCluster(cluster *ekstypes.Cluster, region, accountID string) *accessgraphv1alpha.AWSEKSClusterV1 {
var tags []*accessgraphv1alpha.AWSTag
for k, v := range cluster.Tags {
tags = append(tags, &accessgraphv1alpha.AWSTag{
Key: k,
- Value: strPtrToWrapper(v),
+ Value: wrapperspb.String(v),
})
}
return &accessgraphv1alpha.AWSEKSClusterV1{
- Name: aws.StringValue(cluster.Name),
- Arn: aws.StringValue(cluster.Arn),
+ Name: aws.ToString(cluster.Name),
+ Arn: aws.ToString(cluster.Arn),
CreatedAt: awsTimeToProtoTime(cluster.CreatedAt),
- Status: aws.StringValue(cluster.Status),
+ Status: string(cluster.Status),
Region: region,
AccountId: accountID,
Tags: tags,
@@ -203,33 +213,23 @@ func awsEKSClusterToProtoCluster(cluster *eks.Cluster, region, accountID string)
}
// fetchAccessEntries fetches the access entries for the given cluster.
-func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient eksiface.EKSAPI, cluster *accessgraphv1alpha.AWSEKSClusterV1) ([]*accessgraphv1alpha.AWSEKSClusterAccessEntryV1, error) {
+func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient EKSClient, cluster *accessgraphv1alpha.AWSEKSClusterV1) ([]*accessgraphv1alpha.AWSEKSClusterAccessEntryV1, error) {
var accessEntries []string
- var errs []error
- err := eksClient.ListAccessEntriesPagesWithContext(
- ctx,
- &eks.ListAccessEntriesInput{
- ClusterName: aws.String(cluster.Name),
- },
- func(output *eks.ListAccessEntriesOutput, lastPage bool) bool {
- for _, accessEntry := range output.AccessEntries {
- if aws.StringValue(accessEntry) == "" {
- continue
- }
- accessEntries = append(accessEntries, aws.StringValue(accessEntry))
- }
- return !lastPage
- },
- )
- if err != nil {
- errs = append(errs, trace.Wrap(err))
- return nil, trace.NewAggregate(errs...)
+ for p := eks.NewListAccessEntriesPaginator(eksClient,
+ &eks.ListAccessEntriesInput{ClusterName: aws.String(cluster.Name)},
+ ); p.HasMorePages(); {
+ out, err := p.NextPage(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ accessEntries = append(accessEntries, out.AccessEntries...)
}
+ var errs []error
var protoAccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1
for _, accessEntry := range accessEntries {
- rsp, err := eksClient.DescribeAccessEntryWithContext(
+ rsp, err := eksClient.DescribeAccessEntry(
ctx,
&eks.DescribeAccessEntryInput{
PrincipalArn: aws.String(accessEntry),
@@ -247,84 +247,81 @@ func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient eksiface.
)
protoAccessEntries = append(protoAccessEntries, protoAccessEntry)
}
+
return protoAccessEntries, trace.NewAggregate(errs...)
}
// awsAccessEntryToProtoAccessEntry converts an ekstypes.AccessEntry to accessgraphv1alpha.AWSEKSClusterAccessEntryV1
-func awsAccessEntryToProtoAccessEntry(accessEntry *eks.AccessEntry, cluster *accessgraphv1alpha.AWSEKSClusterV1, accountID string) *accessgraphv1alpha.AWSEKSClusterAccessEntryV1 {
- var tags []*accessgraphv1alpha.AWSTag
+func awsAccessEntryToProtoAccessEntry(accessEntry *ekstypes.AccessEntry, cluster *accessgraphv1alpha.AWSEKSClusterV1, accountID string) *accessgraphv1alpha.AWSEKSClusterAccessEntryV1 {
+ tags := make([]*accessgraphv1alpha.AWSTag, 0, len(accessEntry.Tags))
for k, v := range accessEntry.Tags {
tags = append(tags, &accessgraphv1alpha.AWSTag{
Key: k,
- Value: strPtrToWrapper(v),
+ Value: wrapperspb.String(v),
})
}
- out := &accessgraphv1alpha.AWSEKSClusterAccessEntryV1{
+
+ return &accessgraphv1alpha.AWSEKSClusterAccessEntryV1{
Cluster: cluster,
- AccessEntryArn: aws.StringValue(accessEntry.AccessEntryArn),
+ AccessEntryArn: aws.ToString(accessEntry.AccessEntryArn),
CreatedAt: awsTimeToProtoTime(accessEntry.CreatedAt),
- KubernetesGroups: aws.StringValueSlice(accessEntry.KubernetesGroups),
- Username: aws.StringValue(accessEntry.Username),
+ KubernetesGroups: accessEntry.KubernetesGroups,
+ Username: aws.ToString(accessEntry.Username),
ModifiedAt: awsTimeToProtoTime(accessEntry.ModifiedAt),
- PrincipalArn: aws.StringValue(accessEntry.PrincipalArn),
- Type: aws.StringValue(accessEntry.Type),
+ PrincipalArn: aws.ToString(accessEntry.PrincipalArn),
+ Type: aws.ToString(accessEntry.Type),
Tags: tags,
AccountId: accountID,
LastSyncTime: timestamppb.Now(),
}
-
- return out
}
// fetchAssociatedPolicies fetches the access policies associated with the given cluster's access entries.
-func (a *awsFetcher) fetchAssociatedPolicies(ctx context.Context, eksClient eksiface.EKSAPI, cluster *accessgraphv1alpha.AWSEKSClusterV1, arns []string) ([]*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1, error) {
+func (a *awsFetcher) fetchAssociatedPolicies(ctx context.Context, eksClient EKSClient, cluster *accessgraphv1alpha.AWSEKSClusterV1, arns []string) ([]*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1, error) {
var associatedPolicies []*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1
var errs []error
+
for _, arn := range arns {
- err := eksClient.ListAssociatedAccessPoliciesPagesWithContext(
- ctx,
+ for p := eks.NewListAssociatedAccessPoliciesPaginator(eksClient,
&eks.ListAssociatedAccessPoliciesInput{
ClusterName: aws.String(cluster.Name),
PrincipalArn: aws.String(arn),
},
- func(output *eks.ListAssociatedAccessPoliciesOutput, lastPage bool) bool {
- for _, policy := range output.AssociatedAccessPolicies {
- associatedPolicies = append(associatedPolicies,
- awsAssociatedAccessPolicy(policy, cluster, arn, a.AccountID),
- )
- }
- return !lastPage
- },
- )
- if err != nil {
- errs = append(errs, trace.Wrap(err))
-
+ ); p.HasMorePages(); {
+ out, err := p.NextPage(ctx)
+ if err != nil {
+ errs = append(errs, err)
+ break
+ }
+ for _, policy := range out.AssociatedAccessPolicies {
+ associatedPolicies = append(associatedPolicies,
+ awsAssociatedAccessPolicy(policy, cluster, arn, a.AccountID),
+ )
+ }
}
-
}
return associatedPolicies, trace.NewAggregate(errs...)
}
// awsAssociatedAccessPolicy converts an ekstypes.AssociatedAccessPolicy to accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1
-func awsAssociatedAccessPolicy(policy *eks.AssociatedAccessPolicy, cluster *accessgraphv1alpha.AWSEKSClusterV1, principalARN, accountID string) *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 {
+func awsAssociatedAccessPolicy(policy ekstypes.AssociatedAccessPolicy, cluster *accessgraphv1alpha.AWSEKSClusterV1, principalARN, accountID string) *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 {
var accessScope *accessgraphv1alpha.AWSEKSAccessScopeV1
if policy.AccessScope != nil {
accessScope = &accessgraphv1alpha.AWSEKSAccessScopeV1{
- Namespaces: aws.StringValueSlice(policy.AccessScope.Namespaces),
- Type: aws.StringValue(policy.AccessScope.Type),
+ Namespaces: policy.AccessScope.Namespaces,
+ Type: string(policy.AccessScope.Type),
}
}
- out := &accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{
+
+ return &accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{
Cluster: cluster,
AssociatedAt: awsTimeToProtoTime(policy.AssociatedAt),
ModifiedAt: awsTimeToProtoTime(policy.ModifiedAt),
PrincipalArn: principalARN,
- PolicyArn: aws.StringValue(policy.PolicyArn),
+ PolicyArn: aws.ToString(policy.PolicyArn),
Scope: accessScope,
AccountId: accountID,
LastSyncTime: timestamppb.Now(),
}
-
- return out
}
diff --git a/lib/srv/discovery/fetchers/aws-sync/eks_test.go b/lib/srv/discovery/fetchers/aws-sync/eks_test.go
index 9c6c395018d95..b38f1ff851a92 100644
--- a/lib/srv/discovery/fetchers/aws-sync/eks_test.go
+++ b/lib/srv/discovery/fetchers/aws-sync/eks_test.go
@@ -24,8 +24,9 @@ import (
"testing"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/eks"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
@@ -33,23 +34,82 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
- "github.com/gravitational/teleport/lib/cloud"
- "github.com/gravitational/teleport/lib/cloud/mocks"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
)
-var (
- date = time.Date(2024, 03, 12, 0, 0, 0, 0, time.UTC)
+var date = time.Date(2024, 0o3, 12, 0, 0, 0, 0, time.UTC)
+
+const (
principalARN = "arn:iam:teleport"
accessEntryARN = "arn:iam:access_entry"
)
+type mockedEKSClient struct {
+ clusters []*ekstypes.Cluster
+ accessEntries []*ekstypes.AccessEntry
+ associatedAccessPolicies []ekstypes.AssociatedAccessPolicy
+}
+
+func (m *mockedEKSClient) DescribeCluster(ctx context.Context, input *eks.DescribeClusterInput, optFns ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) {
+ for _, cluster := range m.clusters {
+ if aws.ToString(cluster.Name) == aws.ToString(input.Name) {
+ return &eks.DescribeClusterOutput{
+ Cluster: cluster,
+ }, nil
+ }
+ }
+ return nil, nil
+}
+
+func (m *mockedEKSClient) ListClusters(ctx context.Context, input *eks.ListClustersInput, optFns ...func(*eks.Options)) (*eks.ListClustersOutput, error) {
+ clusterNames := make([]string, 0, len(m.clusters))
+ for _, cluster := range m.clusters {
+ clusterNames = append(clusterNames, aws.ToString(cluster.Name))
+ }
+ return &eks.ListClustersOutput{
+ Clusters: clusterNames,
+ }, nil
+}
+
+func (m *mockedEKSClient) ListAccessEntries(ctx context.Context, input *eks.ListAccessEntriesInput, optFns ...func(*eks.Options)) (*eks.ListAccessEntriesOutput, error) {
+ accessEntries := make([]string, 0, len(m.accessEntries))
+ for _, accessEntry := range m.accessEntries {
+ accessEntries = append(accessEntries, aws.ToString(accessEntry.AccessEntryArn))
+ }
+ return &eks.ListAccessEntriesOutput{
+ AccessEntries: accessEntries,
+ }, nil
+}
+
+func (m *mockedEKSClient) ListAssociatedAccessPolicies(ctx context.Context, input *eks.ListAssociatedAccessPoliciesInput, optFns ...func(*eks.Options)) (*eks.ListAssociatedAccessPoliciesOutput, error) {
+ return &eks.ListAssociatedAccessPoliciesOutput{
+ AssociatedAccessPolicies: m.associatedAccessPolicies,
+ }, nil
+}
+
+func (m *mockedEKSClient) DescribeAccessEntry(ctx context.Context, input *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error) {
+ return &eks.DescribeAccessEntryOutput{
+ AccessEntry: &ekstypes.AccessEntry{
+ PrincipalArn: aws.String(principalARN),
+ AccessEntryArn: aws.String(accessEntryARN),
+ CreatedAt: aws.Time(date),
+ ModifiedAt: aws.Time(date),
+ ClusterName: aws.String("cluster1"),
+ Tags: map[string]string{
+ "t1": "t2",
+ },
+ Type: aws.String(string(ekstypes.AccessScopeTypeCluster)),
+ Username: aws.String("teleport"),
+ KubernetesGroups: []string{"teleport"},
+ },
+ }, nil
+}
+
func TestPollAWSEKSClusters(t *testing.T) {
const (
accountID = "12345678"
)
- var (
- regions = []string{"eu-west-1"}
- )
+ regions := []string{"eu-west-1"}
cluster := &accessgraphv1alpha.AWSEKSClusterV1{
Name: "cluster1",
Arn: "arn:us-west1:eks:cluster1",
@@ -58,7 +118,7 @@ func TestPollAWSEKSClusters(t *testing.T) {
Tags: []*accessgraphv1alpha.AWSTag{
{
Key: "tag1",
- Value: nil,
+ Value: wrapperspb.String(""),
},
{
Key: "tag2",
@@ -102,7 +162,7 @@ func TestPollAWSEKSClusters(t *testing.T) {
Cluster: cluster,
PrincipalArn: principalARN,
Scope: &accessgraphv1alpha.AWSEKSAccessScopeV1{
- Type: eks.AccessScopeTypeCluster,
+ Type: string(ekstypes.AccessScopeTypeCluster),
Namespaces: []string{"ns1"},
},
AssociatedAt: timestamppb.New(date),
@@ -116,12 +176,14 @@ func TestPollAWSEKSClusters(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- mockedClients := &cloud.TestCloudClients{
- EKS: &mocks.EKSMock{
- Clusters: eksClusters(),
- AccessEntries: accessEntries(),
- AssociatedPolicies: associatedPolicies(),
- },
+ t.Parallel()
+
+ getEKSClient := func(_ context.Context, _ string, _ ...awsconfig.OptionsFn) (EKSClient, error) {
+ return &mockedEKSClient{
+ clusters: eksClusters(),
+ accessEntries: accessEntries(),
+ associatedAccessPolicies: associatedPolicies(),
+ }, nil
}
var (
@@ -137,20 +199,21 @@ func TestPollAWSEKSClusters(t *testing.T) {
a := &awsFetcher{
Config: Config{
AccountID: accountID,
- CloudClients: mockedClients,
Regions: regions,
Integration: accountID,
+ GetEKSClient: getEKSClient,
},
lastResult: &Resources{},
}
- result := &Resources{}
- execFunc := a.pollAWSEKSClusters(context.Background(), result, collectErr)
+
+ var result Resources
+ execFunc := a.pollAWSEKSClusters(context.Background(), &result, collectErr)
require.NoError(t, execFunc())
require.Empty(t, cmp.Diff(
tt.want,
- result,
+ &result,
protocmp.Transform(),
- // tags originate from a map so we must sort them before comparing.
+ // Tags originate from a map so we must sort them before comparing.
protocmp.SortRepeated(
func(a, b *accessgraphv1alpha.AWSTag) bool {
return a.Key < b.Key
@@ -159,52 +222,50 @@ func TestPollAWSEKSClusters(t *testing.T) {
protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSClusterV1{}, "last_sync_time"),
protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{}, "last_sync_time"),
protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSClusterAccessEntryV1{}, "last_sync_time"),
- ),
- )
-
+ ))
})
}
}
-func eksClusters() []*eks.Cluster {
- return []*eks.Cluster{
+func eksClusters() []*ekstypes.Cluster {
+ return []*ekstypes.Cluster{
{
Name: aws.String("cluster1"),
Arn: aws.String("arn:us-west1:eks:cluster1"),
CreatedAt: aws.Time(date),
- Status: aws.String(eks.AddonStatusActive),
- Tags: map[string]*string{
- "tag1": nil,
- "tag2": aws.String("val2"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "tag1": "",
+ "tag2": "val2",
},
},
}
}
-func accessEntries() []*eks.AccessEntry {
- return []*eks.AccessEntry{
+func accessEntries() []*ekstypes.AccessEntry {
+ return []*ekstypes.AccessEntry{
{
PrincipalArn: aws.String(principalARN),
AccessEntryArn: aws.String(accessEntryARN),
CreatedAt: aws.Time(date),
ModifiedAt: aws.Time(date),
ClusterName: aws.String("cluster1"),
- Tags: map[string]*string{
- "t1": aws.String("t2"),
+ Tags: map[string]string{
+ "t1": "t2",
},
- Type: aws.String(eks.AccessScopeTypeCluster),
+ Type: aws.String(string(ekstypes.AccessScopeTypeCluster)),
Username: aws.String("teleport"),
- KubernetesGroups: []*string{aws.String("teleport")},
+ KubernetesGroups: []string{"teleport"},
},
}
}
-func associatedPolicies() []*eks.AssociatedAccessPolicy {
- return []*eks.AssociatedAccessPolicy{
+func associatedPolicies() []ekstypes.AssociatedAccessPolicy {
+ return []ekstypes.AssociatedAccessPolicy{
{
- AccessScope: &eks.AccessScope{
- Namespaces: []*string{aws.String("ns1")},
- Type: aws.String(eks.AccessScopeTypeCluster),
+ AccessScope: &ekstypes.AccessScope{
+ Namespaces: []string{"ns1"},
+ Type: ekstypes.AccessScopeTypeCluster,
},
ModifiedAt: aws.Time(date),
AssociatedAt: aws.Time(date),
diff --git a/lib/srv/discovery/fetchers/aws-sync/rds.go b/lib/srv/discovery/fetchers/aws-sync/rds.go
index 08195e2132e82..f163c49f6b6d3 100644
--- a/lib/srv/discovery/fetchers/aws-sync/rds.go
+++ b/lib/srv/discovery/fetchers/aws-sync/rds.go
@@ -22,8 +22,9 @@ import (
"context"
"sync"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -31,12 +32,18 @@ import (
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
)
+// rdsClient defines a subset of the AWS RDS client API.
+type rdsClient interface {
+ rds.DescribeDBClustersAPIClient
+ rds.DescribeDBInstancesAPIClient
+}
+
// pollAWSRDSDatabases is a function that returns a function that fetches
// RDS instances and clusters.
func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, collectErr func(error)) func() error {
return func() error {
var err error
- result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx, a.lastResult)
+ result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx)
if err != nil {
collectErr(trace.Wrap(err, "failed to fetch databases"))
}
@@ -45,7 +52,7 @@ func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources,
}
// fetchAWSRDSDatabases fetches RDS databases from all regions.
-func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resources) (
+func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context) (
[]*accessgraphv1alpha.AWSRDSDatabaseV1,
error,
) {
@@ -59,14 +66,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc
// This is a temporary solution until we have a better way to limit the
// number of concurrent requests.
eG.SetLimit(5)
- collectDBs := func(db *accessgraphv1alpha.AWSRDSDatabaseV1, err error) {
+ collectDBs := func(db []*accessgraphv1alpha.AWSRDSDatabaseV1, err error) {
hostsMu.Lock()
defer hostsMu.Unlock()
if err != nil {
errs = append(errs, err)
}
if db != nil {
- dbs = append(dbs, db)
+ dbs = append(dbs, db...)
}
}
@@ -74,42 +81,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc
for _, region := range a.Regions {
region := region
eG.Go(func() error {
- rdsClient, err := a.CloudClients.GetAWSRDSClient(ctx, region, a.getAWSOptions()...)
+ awsCfg, err := a.AWSConfigProvider.GetConfig(ctx, region, a.getAWSV2Options()...)
if err != nil {
collectDBs(nil, trace.Wrap(err))
return nil
}
- err = rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{},
- func(output *rds.DescribeDBInstancesOutput, lastPage bool) bool {
- for _, db := range output.DBInstances {
- // if instance belongs to a cluster, skip it as we want to represent the cluster itself
- // and we pull it using DescribeDBClustersPagesWithContext instead.
- if aws.StringValue(db.DBClusterIdentifier) != "" {
- continue
- }
- protoRDS := awsRDSInstanceToRDS(db, region, a.AccountID)
- collectDBs(protoRDS, nil)
- }
- return !lastPage
- },
- )
- if err != nil {
- collectDBs(nil, trace.Wrap(err))
- }
-
- err = rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{},
- func(output *rds.DescribeDBClustersOutput, lastPage bool) bool {
- for _, db := range output.DBClusters {
- protoRDS := awsRDSClusterToRDS(db, region, a.AccountID)
- collectDBs(protoRDS, nil)
- }
- return !lastPage
- },
- )
- if err != nil {
- collectDBs(nil, trace.Wrap(err))
- }
-
+ clt := a.awsClients.getRDSClient(awsCfg)
+ a.collectDBInstances(ctx, clt, region, collectDBs)
+ a.collectDBClusters(ctx, clt, region, collectDBs)
return nil
})
}
@@ -118,60 +97,123 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc
return dbs, trace.NewAggregate(append(errs, err)...)
}
-// awsRDSInstanceToRDS converts an rds.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1
+// awsRDSInstanceToRDS converts an rdstypes.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1
// representation.
-func awsRDSInstanceToRDS(instance *rds.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
+func awsRDSInstanceToRDS(instance *rdstypes.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, v := range instance.TagList {
tags = append(tags, &accessgraphv1alpha.AWSTag{
- Key: aws.StringValue(v.Key),
+ Key: aws.ToString(v.Key),
Value: strPtrToWrapper(v.Value),
})
}
return &accessgraphv1alpha.AWSRDSDatabaseV1{
- Name: aws.StringValue(instance.DBInstanceIdentifier),
- Arn: aws.StringValue(instance.DBInstanceArn),
+ Name: aws.ToString(instance.DBInstanceIdentifier),
+ Arn: aws.ToString(instance.DBInstanceArn),
CreatedAt: awsTimeToProtoTime(instance.InstanceCreateTime),
- Status: aws.StringValue(instance.DBInstanceStatus),
+ Status: aws.ToString(instance.DBInstanceStatus),
Region: region,
AccountId: accountID,
Tags: tags,
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
- Engine: aws.StringValue(instance.Engine),
- Version: aws.StringValue(instance.EngineVersion),
+ Engine: aws.ToString(instance.Engine),
+ Version: aws.ToString(instance.EngineVersion),
},
IsCluster: false,
- ResourceId: aws.StringValue(instance.DbiResourceId),
+ ResourceId: aws.ToString(instance.DbiResourceId),
LastSyncTime: timestamppb.Now(),
}
}
-// awsRDSInstanceToRDS converts an rds.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1
+// awsRDSClusterToRDS converts an rdstypes.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1
// representation.
-func awsRDSClusterToRDS(instance *rds.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
+func awsRDSClusterToRDS(instance *rdstypes.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, v := range instance.TagList {
tags = append(tags, &accessgraphv1alpha.AWSTag{
- Key: aws.StringValue(v.Key),
+ Key: aws.ToString(v.Key),
Value: strPtrToWrapper(v.Value),
})
}
return &accessgraphv1alpha.AWSRDSDatabaseV1{
- Name: aws.StringValue(instance.DBClusterIdentifier),
- Arn: aws.StringValue(instance.DBClusterArn),
+ Name: aws.ToString(instance.DBClusterIdentifier),
+ Arn: aws.ToString(instance.DBClusterArn),
CreatedAt: awsTimeToProtoTime(instance.ClusterCreateTime),
- Status: aws.StringValue(instance.Status),
+ Status: aws.ToString(instance.Status),
Region: region,
AccountId: accountID,
Tags: tags,
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
- Engine: aws.StringValue(instance.Engine),
- Version: aws.StringValue(instance.EngineVersion),
+ Engine: aws.ToString(instance.Engine),
+ Version: aws.ToString(instance.EngineVersion),
},
IsCluster: true,
- ResourceId: aws.StringValue(instance.DbClusterResourceId),
+ ResourceId: aws.ToString(instance.DbClusterResourceId),
LastSyncTime: timestamppb.Now(),
}
}
+
+func (a *awsFetcher) collectDBInstances(ctx context.Context,
+ clt rdsClient,
+ region string,
+ collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error),
+) {
+ pager := rds.NewDescribeDBInstancesPaginator(clt,
+ &rds.DescribeDBInstancesInput{},
+ func(ddpo *rds.DescribeDBInstancesPaginatorOptions) {
+ ddpo.StopOnDuplicateToken = true
+ },
+ )
+ var instances []*accessgraphv1alpha.AWSRDSDatabaseV1
+ for pager.HasMorePages() {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool {
+ return !db.IsCluster && db.Region == region && db.AccountId == a.AccountID
+ })
+ collectDBs(old, trace.Wrap(err))
+ return
+ }
+ for _, db := range page.DBInstances {
+ // if instance belongs to a cluster, skip it as we want to represent the cluster itself
+ // and we pull it using DescribeDBClustersPaginator instead.
+ if aws.ToString(db.DBClusterIdentifier) != "" {
+ continue
+ }
+ protoRDS := awsRDSInstanceToRDS(&db, region, a.AccountID)
+ instances = append(instances, protoRDS)
+ }
+ }
+ collectDBs(instances, nil)
+}
+
+func (a *awsFetcher) collectDBClusters(
+ ctx context.Context,
+ clt rdsClient,
+ region string,
+ collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error),
+) {
+ pager := rds.NewDescribeDBClustersPaginator(clt, &rds.DescribeDBClustersInput{},
+ func(ddpo *rds.DescribeDBClustersPaginatorOptions) {
+ ddpo.StopOnDuplicateToken = true
+ },
+ )
+ var clusters []*accessgraphv1alpha.AWSRDSDatabaseV1
+ for pager.HasMorePages() {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool {
+ return db.IsCluster && db.Region == region && db.AccountId == a.AccountID
+ })
+ collectDBs(old, trace.Wrap(err))
+ return
+ }
+ for _, db := range page.DBClusters {
+ protoRDS := awsRDSClusterToRDS(&db, region, a.AccountID)
+ clusters = append(clusters, protoRDS)
+ }
+ }
+ collectDBs(clusters, nil)
+}
diff --git a/lib/srv/discovery/fetchers/aws-sync/rds_test.go b/lib/srv/discovery/fetchers/aws-sync/rds_test.go
index bed0811d88e1d..b228264b7c834 100644
--- a/lib/srv/discovery/fetchers/aws-sync/rds_test.go
+++ b/lib/srv/discovery/fetchers/aws-sync/rds_test.go
@@ -20,19 +20,19 @@ package aws_sync
import (
"context"
- "sync"
"testing"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
+ "github.com/gravitational/teleport/api/types"
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
- "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
)
@@ -44,86 +44,117 @@ func TestPollAWSRDS(t *testing.T) {
regions = []string{"eu-west-1"}
)
- tests := []struct {
- name string
- want *Resources
- }{
- {
- name: "poll rds databases",
- want: &Resources{
- RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{
+ awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(
+ types.Metadata{Name: "integration-test"},
+ &types.AWSOIDCIntegrationSpecV1{
+ RoleARN: "arn:aws:sts::123456789012:role/TestRole",
+ },
+ )
+ require.NoError(t, err)
+
+ resourcesFixture := Resources{
+ RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{
+ {
+ Arn: "arn:us-west1:rds:instance1",
+ Status: string(rdstypes.DBProxyStatusAvailable),
+ Name: "db1",
+ EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
+ Engine: string(rdstypes.EngineFamilyMysql),
+ Version: "v1.1",
+ },
+ CreatedAt: timestamppb.New(date),
+ Tags: []*accessgraphv1alpha.AWSTag{
{
- Arn: "arn:us-west1:rds:instance1",
- Status: rds.DBProxyStatusAvailable,
- Name: "db1",
- EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
- Engine: rds.EngineFamilyMysql,
- Version: "v1.1",
- },
- CreatedAt: timestamppb.New(date),
- Tags: []*accessgraphv1alpha.AWSTag{
- {
- Key: "tag",
- Value: wrapperspb.String("val"),
- },
- },
- Region: "eu-west-1",
- IsCluster: false,
- AccountId: "12345678",
- ResourceId: "db1",
+ Key: "tag",
+ Value: wrapperspb.String("val"),
},
+ },
+ Region: "eu-west-1",
+ IsCluster: false,
+ AccountId: "12345678",
+ ResourceId: "db1",
+ },
+ {
+ Arn: "arn:us-west1:rds:cluster1",
+ Status: string(rdstypes.DBProxyStatusAvailable),
+ Name: "cluster1",
+ EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
+ Engine: string(rdstypes.EngineFamilyMysql),
+ Version: "v1.1",
+ },
+ CreatedAt: timestamppb.New(date),
+ Tags: []*accessgraphv1alpha.AWSTag{
{
- Arn: "arn:us-west1:rds:cluster1",
- Status: rds.DBProxyStatusAvailable,
- Name: "cluster1",
- EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
- Engine: rds.EngineFamilyMysql,
- Version: "v1.1",
- },
- CreatedAt: timestamppb.New(date),
- Tags: []*accessgraphv1alpha.AWSTag{
- {
- Key: "tag",
- Value: wrapperspb.String("val"),
- },
- },
- Region: "eu-west-1",
- IsCluster: true,
- AccountId: "12345678",
- ResourceId: "cluster1",
+ Key: "tag",
+ Value: wrapperspb.String("val"),
},
},
+ Region: "eu-west-1",
+ IsCluster: true,
+ AccountId: "12345678",
+ ResourceId: "cluster1",
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ fetcherConfigOpt func(*awsFetcher)
+ want *Resources
+ checkError func(*testing.T, error)
+ }{
+ {
+ name: "poll rds databases",
+ want: &resourcesFixture,
+ fetcherConfigOpt: func(a *awsFetcher) {
+ a.awsClients = fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBInstances: dbInstances(),
+ DBClusters: dbClusters(),
+ },
+ }
+ },
+ checkError: func(t *testing.T, err error) {
+ require.NoError(t, err)
+ },
+ },
+ {
+ name: "reuse last synced databases on failure",
+ want: &resourcesFixture,
+ fetcherConfigOpt: func(a *awsFetcher) {
+ a.awsClients = fakeAWSClients{
+ rdsClient: &mocks.RDSClient{Unauth: true},
+ }
+ a.lastResult = &resourcesFixture
+ },
+ checkError: func(t *testing.T, err error) {
+ require.Error(t, err)
+ require.ErrorContains(t, err, "failed to fetch databases")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- mockedClients := &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBInstances: dbInstances(),
- DBClusters: dbClusters(),
- },
- }
-
- var (
- errs []error
- mu sync.Mutex
- )
-
- collectErr := func(err error) {
- mu.Lock()
- defer mu.Unlock()
- errs = append(errs, err)
- }
a := &awsFetcher{
Config: Config{
- AccountID: accountID,
- CloudClients: mockedClients,
- Regions: regions,
- Integration: accountID,
+ AccountID: accountID,
+ AWSConfigProvider: &mocks.AWSConfigProvider{
+ OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{
+ Integration: awsOIDCIntegration,
+ Token: "fake-oidc-token",
+ },
+ },
+ Regions: regions,
+ Integration: awsOIDCIntegration.GetName(),
},
}
+ if tt.fetcherConfigOpt != nil {
+ tt.fetcherConfigOpt(a)
+ }
result := &Resources{}
+ collectErr := func(err error) {
+ tt.checkError(t, err)
+ }
execFunc := a.pollAWSRDSDatabases(context.Background(), result, collectErr)
require.NoError(t, execFunc())
require.Empty(t, cmp.Diff(
@@ -144,16 +175,16 @@ func TestPollAWSRDS(t *testing.T) {
}
}
-func dbInstances() []*rds.DBInstance {
- return []*rds.DBInstance{
+func dbInstances() []rdstypes.DBInstance {
+ return []rdstypes.DBInstance{
{
DBInstanceIdentifier: aws.String("db1"),
DBInstanceArn: aws.String("arn:us-west1:rds:instance1"),
InstanceCreateTime: aws.Time(date),
- Engine: aws.String(rds.EngineFamilyMysql),
- DBInstanceStatus: aws.String(rds.DBProxyStatusAvailable),
+ Engine: aws.String(string(rdstypes.EngineFamilyMysql)),
+ DBInstanceStatus: aws.String(string(rdstypes.DBProxyStatusAvailable)),
EngineVersion: aws.String("v1.1"),
- TagList: []*rds.Tag{
+ TagList: []rdstypes.Tag{
{
Key: aws.String("tag"),
Value: aws.String("val"),
@@ -164,16 +195,16 @@ func dbInstances() []*rds.DBInstance {
}
}
-func dbClusters() []*rds.DBCluster {
- return []*rds.DBCluster{
+func dbClusters() []rdstypes.DBCluster {
+ return []rdstypes.DBCluster{
{
DBClusterIdentifier: aws.String("cluster1"),
DBClusterArn: aws.String("arn:us-west1:rds:cluster1"),
ClusterCreateTime: aws.Time(date),
- Engine: aws.String(rds.EngineFamilyMysql),
- Status: aws.String(rds.DBProxyStatusAvailable),
+ Engine: aws.String(string(rdstypes.EngineFamilyMysql)),
+ Status: aws.String(string(rdstypes.DBProxyStatusAvailable)),
EngineVersion: aws.String("v1.1"),
- TagList: []*rds.Tag{
+ TagList: []rdstypes.Tag{
{
Key: aws.String("tag"),
Value: aws.String("val"),
@@ -183,3 +214,11 @@ func dbClusters() []*rds.DBCluster {
},
}
}
+
+type fakeAWSClients struct {
+ rdsClient rdsClient
+}
+
+func (f fakeAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient {
+ return f.rdsClient
+}
diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go
index d6d70912d7092..24de91e83e309 100644
--- a/lib/srv/discovery/fetchers/db/aws.go
+++ b/lib/srv/discovery/fetchers/db/aws.go
@@ -23,8 +23,6 @@ import (
"fmt"
"log/slog"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
@@ -74,8 +72,8 @@ type awsFetcherConfig struct {
// ie teleport.yaml/discovery_service..
DiscoveryConfigName string
- // redshiftClientProviderFn provides an AWS Redshift client.
- redshiftClientProviderFn RedshiftClientProviderFunc
+ // awsClients provides AWS SDK v2 clients.
+ awsClients AWSClientProvider
}
// CheckAndSetDefaults validates the config and sets defaults.
@@ -109,10 +107,8 @@ func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error {
)
}
- if cfg.redshiftClientProviderFn == nil {
- cfg.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient {
- return redshift.NewFromConfig(cfg, optFns...)
- }
+ if cfg.awsClients == nil {
+ cfg.awsClients = defaultAWSClients{}
}
return nil
}
diff --git a/lib/srv/discovery/fetchers/db/aws_docdb.go b/lib/srv/discovery/fetchers/db/aws_docdb.go
index a6a604be340eb..ef1920d83d6b8 100644
--- a/lib/srv/discovery/fetchers/db/aws_docdb.go
+++ b/lib/srv/discovery/fetchers/db/aws_docdb.go
@@ -21,14 +21,14 @@ package db
import (
"context"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
@@ -39,13 +39,6 @@ func newDocumentDBFetcher(cfg awsFetcherConfig) (common.Fetcher, error) {
}
// rdsDocumentDBFetcher retrieves DocumentDB clusters.
-//
-// Note that AWS DocumentDB internally uses the RDS APIs:
-// https://github.com/aws/aws-sdk-go/blob/3248e69e16aa601ffa929be53a52439425257e5e/service/docdb/service.go#L33
-// The interfaces/structs in "services/docdb" are usually a subset of those in
-// "services/rds".
-//
-// TODO(greedy52) switch to aws-sdk-go-v2/services/docdb.
type rdsDocumentDBFetcher struct{}
func (f *rdsDocumentDBFetcher) ComponentShortName() string {
@@ -54,21 +47,22 @@ func (f *rdsDocumentDBFetcher) ComponentShortName() string {
// GetDatabases returns a list of database resources representing DocumentDB endpoints.
func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) {
- rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region,
- cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
- cloud.WithCredentialsMaybeIntegration(cfg.Integration),
+ awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region,
+ awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
+ awsconfig.WithCredentialsMaybeIntegration(cfg.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
}
- clusters, err := f.getAllDBClusters(ctx, rdsClient)
+ clt := cfg.awsClients.GetRDSClient(awsCfg)
+ clusters, err := f.getAllDBClusters(ctx, clt)
if err != nil {
- return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+ return nil, trace.Wrap(err)
}
databases := make(types.Databases, 0)
for _, cluster := range clusters {
- if !libcloudaws.IsDocumentDBClusterSupported(cluster) {
+ if !libcloudaws.IsDocumentDBClusterSupported(&cluster) {
cfg.Logger.DebugContext(ctx, "DocumentDB cluster doesn't support IAM authentication. Skipping.",
"cluster", aws.StringValue(cluster.DBClusterIdentifier),
"engine_version", aws.StringValue(cluster.EngineVersion))
@@ -82,7 +76,7 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher
continue
}
- dbs, err := common.NewDatabasesFromDocumentDBCluster(cluster)
+ dbs, err := common.NewDatabasesFromDocumentDBCluster(&cluster)
if err != nil {
cfg.Logger.WarnContext(ctx, "Could not convert DocumentDB cluster to database resources.",
"cluster", aws.StringValue(cluster.DBClusterIdentifier),
@@ -93,15 +87,23 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher
return databases, nil
}
-func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI) ([]*rds.DBCluster, error) {
- var pageNum int
- var clusters []*rds.DBCluster
- err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
- Filters: rdsEngineFilter([]string{"docdb"}),
- }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool {
- pageNum++
- clusters = append(clusters, ddo.DBClusters...)
- return pageNum <= maxAWSPages
- })
- return clusters, trace.Wrap(err)
+func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, clt RDSClient) ([]rdstypes.DBCluster, error) {
+ pager := rds.NewDescribeDBClustersPaginator(clt,
+ &rds.DescribeDBClustersInput{
+ Filters: rdsEngineFilter([]string{"docdb"}),
+ },
+ func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) {
+ pagerOpts.StopOnDuplicateToken = true
+ },
+ )
+
+ var clusters []rdstypes.DBCluster
+ for i := 0; i < maxAWSPages && pager.HasMorePages(); i++ {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
+ }
+ clusters = append(clusters, page.DBClusters...)
+ }
+ return clusters, nil
}
diff --git a/lib/srv/discovery/fetchers/db/aws_docdb_test.go b/lib/srv/discovery/fetchers/db/aws_docdb_test.go
index 4ae7cfee582f0..5f71a805f8131 100644
--- a/lib/srv/discovery/fetchers/db/aws_docdb_test.go
+++ b/lib/srv/discovery/fetchers/db/aws_docdb_test.go
@@ -21,12 +21,11 @@ package db
import (
"testing"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
@@ -34,16 +33,16 @@ import (
func TestDocumentDBFetcher(t *testing.T) {
t.Parallel()
- docdbEngine := &rds.DBEngineVersion{
+ docdbEngine := &rdstypes.DBEngineVersion{
Engine: aws.String("docdb"),
}
clusterProd := mocks.DocumentDBCluster("cluster1", "us-east-1", envProdLabels, mocks.WithDocumentDBClusterReader)
clusterDev := mocks.DocumentDBCluster("cluster2", "us-east-1", envDevLabels)
- clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) {
+ clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) {
cluster.Status = aws.String("creating")
})
- clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) {
+ clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) {
cluster.EngineVersion = aws.String("4.0.0")
})
@@ -53,10 +52,12 @@ func TestDocumentDBFetcher(t *testing.T) {
tests := []awsFetcherTest{
{
name: "fetch all",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBClusters: []*rds.DBCluster{clusterProd, clusterDev},
- DBEngineVersions: []*rds.DBEngineVersion{docdbEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{
@@ -70,10 +71,12 @@ func TestDocumentDBFetcher(t *testing.T) {
},
{
name: "filter by labels",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBClusters: []*rds.DBCluster{clusterProd, clusterDev},
- DBEngineVersions: []*rds.DBEngineVersion{docdbEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{
@@ -87,10 +90,12 @@ func TestDocumentDBFetcher(t *testing.T) {
},
{
name: "skip unsupported databases",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBClusters: []*rds.DBCluster{clusterProd, clusterNotSupported},
- DBEngineVersions: []*rds.DBEngineVersion{docdbEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotSupported},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{
@@ -104,10 +109,12 @@ func TestDocumentDBFetcher(t *testing.T) {
},
{
name: "skip unavailable databases",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBClusters: []*rds.DBCluster{clusterProd, clusterNotAvailable},
- DBEngineVersions: []*rds.DBEngineVersion{docdbEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotAvailable},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{
@@ -123,7 +130,7 @@ func TestDocumentDBFetcher(t *testing.T) {
testAWSFetchers(t, tests...)
}
-func mustMakeDocumentDBDatabases(t *testing.T, cluster *rds.DBCluster) types.Databases {
+func mustMakeDocumentDBDatabases(t *testing.T, cluster *rdstypes.DBCluster) types.Databases {
t.Helper()
databases, err := common.NewDatabasesFromDocumentDBCluster(cluster)
diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go
index 639835f2b75a2..1b438873c8726 100644
--- a/lib/srv/discovery/fetchers/db/aws_rds.go
+++ b/lib/srv/discovery/fetchers/db/aws_rds.go
@@ -23,18 +23,27 @@ import (
"log/slog"
"strings"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
+// RDSClient is a subset of the AWS RDS API.
+type RDSClient interface {
+ rds.DescribeDBClustersAPIClient
+ rds.DescribeDBInstancesAPIClient
+ rds.DescribeDBProxiesAPIClient
+ rds.DescribeDBProxyEndpointsAPIClient
+ ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error)
+}
+
// newRDSDBInstancesFetcher returns a new AWS fetcher for RDS databases.
func newRDSDBInstancesFetcher(cfg awsFetcherConfig) (common.Fetcher, error) {
return newAWSFetcher(cfg, &rdsDBInstancesPlugin{})
@@ -49,40 +58,41 @@ func (f *rdsDBInstancesPlugin) ComponentShortName() string {
// GetDatabases returns a list of database resources representing RDS instances.
func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) {
- rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region,
- cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
- cloud.WithCredentialsMaybeIntegration(cfg.Integration),
+ awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region,
+ awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
+ awsconfig.WithCredentialsMaybeIntegration(cfg.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
}
- instances, err := getAllDBInstances(ctx, rdsClient, maxAWSPages, cfg.Logger)
+ clt := cfg.awsClients.GetRDSClient(awsCfg)
+ instances, err := getAllDBInstances(ctx, clt, maxAWSPages, cfg.Logger)
if err != nil {
- return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+ return nil, trace.Wrap(err)
}
databases := make(types.Databases, 0, len(instances))
for _, instance := range instances {
- if !libcloudaws.IsRDSInstanceSupported(instance) {
+ if !libcloudaws.IsRDSInstanceSupported(&instance) {
cfg.Logger.DebugContext(ctx, "Skipping RDS instance that does not support IAM authentication",
- "instance", aws.StringValue(instance.DBInstanceIdentifier),
- "engine_mode", aws.StringValue(instance.Engine),
- "engine_version", aws.StringValue(instance.EngineVersion),
+ "instance", aws.ToString(instance.DBInstanceIdentifier),
+ "engine_mode", aws.ToString(instance.Engine),
+ "engine_version", aws.ToString(instance.EngineVersion),
)
continue
}
if !libcloudaws.IsRDSInstanceAvailable(instance.DBInstanceStatus, instance.DBInstanceIdentifier) {
cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS instance",
- "instance", aws.StringValue(instance.DBInstanceIdentifier),
- "status", aws.StringValue(instance.DBInstanceStatus),
+ "instance", aws.ToString(instance.DBInstanceIdentifier),
+ "status", aws.ToString(instance.DBInstanceStatus),
)
continue
}
- database, err := common.NewDatabaseFromRDSInstance(instance)
+ database, err := common.NewDatabaseFromRDSInstance(&instance)
if err != nil {
cfg.Logger.WarnContext(ctx, "Could not convert RDS instance to database resource",
- "instance", aws.StringValue(instance.DBInstanceIdentifier),
+ "instance", aws.ToString(instance.DBInstanceIdentifier),
"error", err,
)
} else {
@@ -94,36 +104,40 @@ func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcher
// getAllDBInstances fetches all RDS instances using the provided client, up
// to the specified max number of pages.
-func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBInstance, error) {
- return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger)
+func getAllDBInstances(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBInstance, error) {
+ return getAllDBInstancesWithFilters(ctx, clt, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger)
}
// findDBInstancesForDBCluster returns the DBInstances associated with a given DB Cluster Identifier
-func findDBInstancesForDBCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]*rds.DBInstance, error) {
- return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger)
+func findDBInstancesForDBCluster(ctx context.Context, clt RDSClient, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]rdstypes.DBInstance, error) {
+ return getAllDBInstancesWithFilters(ctx, clt, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger)
}
// getAllDBInstancesWithFilters fetches all RDS instances matching the filters using the provided client, up
// to the specified max number of pages.
-func getAllDBInstancesWithFilters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, engines []string, baseFilters []*rds.Filter, logger *slog.Logger) ([]*rds.DBInstance, error) {
- var instances []*rds.DBInstance
- err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []*rds.Filter) error {
- var pageNum int
- var out []*rds.DBInstance
- err := rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{
- Filters: append(engineFilters, baseFilters...),
- }, func(ddo *rds.DescribeDBInstancesOutput, lastPage bool) bool {
- pageNum++
- instances = append(instances, ddo.DBInstances...)
- return pageNum <= maxPages
- })
- if err == nil {
- // only append to instances on nil error, just in case we have to retry.
- instances = append(instances, out...)
+func getAllDBInstancesWithFilters(ctx context.Context, clt RDSClient, maxPages int, engines []string, baseFilters []rdstypes.Filter, logger *slog.Logger) ([]rdstypes.DBInstance, error) {
+ var out []rdstypes.DBInstance
+ err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []rdstypes.Filter) error {
+ pager := rds.NewDescribeDBInstancesPaginator(clt,
+ &rds.DescribeDBInstancesInput{
+ Filters: append(engineFilters, baseFilters...),
+ },
+ func(dcpo *rds.DescribeDBInstancesPaginatorOptions) {
+ dcpo.StopOnDuplicateToken = true
+ },
+ )
+ var instances []rdstypes.DBInstance
+ for i := 0; i < maxPages && pager.HasMorePages(); i++ {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ instances = append(instances, page.DBInstances...)
}
- return trace.Wrap(err)
+ out = instances
+ return nil
})
- return instances, trace.Wrap(err)
+ return out, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
}
// newRDSAuroraClustersFetcher returns a new AWS fetcher for RDS Aurora
@@ -141,48 +155,49 @@ func (f *rdsAuroraClustersPlugin) ComponentShortName() string {
// GetDatabases returns a list of database resources representing RDS clusters.
func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) {
- rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region,
- cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
- cloud.WithCredentialsMaybeIntegration(cfg.Integration),
+ awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region,
+ awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
+ awsconfig.WithCredentialsMaybeIntegration(cfg.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
}
- clusters, err := getAllDBClusters(ctx, rdsClient, maxAWSPages, cfg.Logger)
+ clt := cfg.awsClients.GetRDSClient(awsCfg)
+ clusters, err := getAllDBClusters(ctx, clt, maxAWSPages, cfg.Logger)
if err != nil {
- return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+ return nil, trace.Wrap(err)
}
databases := make(types.Databases, 0, len(clusters))
for _, cluster := range clusters {
- if !libcloudaws.IsRDSClusterSupported(cluster) {
+ if !libcloudaws.IsRDSClusterSupported(&cluster) {
cfg.Logger.DebugContext(ctx, "Skipping Aurora cluster that does not support IAM authentication",
- "cluster", aws.StringValue(cluster.DBClusterIdentifier),
- "engine_mode", aws.StringValue(cluster.EngineMode),
- "engine_version", aws.StringValue(cluster.EngineVersion),
+ "cluster", aws.ToString(cluster.DBClusterIdentifier),
+ "engine_mode", aws.ToString(cluster.EngineMode),
+ "engine_version", aws.ToString(cluster.EngineVersion),
)
continue
}
if !libcloudaws.IsDBClusterAvailable(cluster.Status, cluster.DBClusterIdentifier) {
cfg.Logger.DebugContext(ctx, "Skipping unavailable Aurora cluster",
- "instance", aws.StringValue(cluster.DBClusterIdentifier),
- "status", aws.StringValue(cluster.Status),
+ "instance", aws.ToString(cluster.DBClusterIdentifier),
+ "status", aws.ToString(cluster.Status),
)
continue
}
- rdsDBInstances, err := findDBInstancesForDBCluster(ctx, rdsClient, maxAWSPages, aws.StringValue(cluster.DBClusterIdentifier), cfg.Logger)
+ rdsDBInstances, err := findDBInstancesForDBCluster(ctx, clt, maxAWSPages, aws.ToString(cluster.DBClusterIdentifier), cfg.Logger)
if err != nil || len(rdsDBInstances) == 0 {
cfg.Logger.WarnContext(ctx, "Could not fetch Member Instance for DB Cluster",
- "instance", aws.StringValue(cluster.DBClusterIdentifier),
+ "instance", aws.ToString(cluster.DBClusterIdentifier),
"error", err,
)
}
- dbs, err := common.NewDatabasesFromRDSCluster(cluster, rdsDBInstances)
+ dbs, err := common.NewDatabasesFromRDSCluster(&cluster, rdsDBInstances)
if err != nil {
cfg.Logger.WarnContext(ctx, "Could not convert RDS cluster to database resources",
- "identifier", aws.StringValue(cluster.DBClusterIdentifier),
+ "identifier", aws.ToString(cluster.DBClusterIdentifier),
"error", err,
)
}
@@ -193,25 +208,30 @@ func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetc
// getAllDBClusters fetches all RDS clusters using the provided client, up to
// the specified max number of pages.
-func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBCluster, error) {
- var clusters []*rds.DBCluster
- err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []*rds.Filter) error {
- var pageNum int
- var out []*rds.DBCluster
- err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
- Filters: filters,
- }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool {
- pageNum++
- out = append(out, ddo.DBClusters...)
- return pageNum <= maxPages
- })
- if err == nil {
- // only append to clusters on nil error, just in case we have to retry.
- clusters = append(clusters, out...)
+func getAllDBClusters(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBCluster, error) {
+ var out []rdstypes.DBCluster
+ err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []rdstypes.Filter) error {
+ pager := rds.NewDescribeDBClustersPaginator(clt,
+ &rds.DescribeDBClustersInput{
+ Filters: filters,
+ },
+ func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) {
+ pagerOpts.StopOnDuplicateToken = true
+ },
+ )
+
+ var clusters []rdstypes.DBCluster
+ for i := 0; i < maxPages && pager.HasMorePages(); i++ {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ clusters = append(clusters, page.DBClusters...)
}
- return trace.Wrap(err)
+ out = clusters
+ return nil
})
- return clusters, trace.Wrap(err)
+ return out, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
}
// rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns
@@ -234,28 +254,28 @@ func auroraEngines() []string {
}
// rdsEngineFilter is a helper func to construct an RDS filter for engine names.
-func rdsEngineFilter(engines []string) []*rds.Filter {
- return []*rds.Filter{{
+func rdsEngineFilter(engines []string) []rdstypes.Filter {
+ return []rdstypes.Filter{{
Name: aws.String("engine"),
- Values: aws.StringSlice(engines),
+ Values: engines,
}}
}
// rdsClusterIDFilter is a helper func to construct an RDS DB Instances for returning Instances of a specific DB Cluster.
-func rdsClusterIDFilter(clusterIdentifier string) []*rds.Filter {
- return []*rds.Filter{{
+func rdsClusterIDFilter(clusterIdentifier string) []rdstypes.Filter {
+ return []rdstypes.Filter{{
Name: aws.String("db-cluster-id"),
- Values: aws.StringSlice([]string{clusterIdentifier}),
+ Values: []string{clusterIdentifier},
}}
}
// rdsEmptyFilter is a helper func to construct an empty RDS filter.
-func rdsEmptyFilter() []*rds.Filter {
- return []*rds.Filter{}
+func rdsEmptyFilter() []rdstypes.Filter {
+ return []rdstypes.Filter{}
}
// rdsFilterFn is a function that takes RDS filters and performs some operation with them, returning any error encountered.
-type rdsFilterFn func([]*rds.Filter) error
+type rdsFilterFn func([]rdstypes.Filter) error
// retryWithIndividualEngineFilters is a helper error handling function for AWS RDS unrecognized engine name filter errors,
// that will call the provided RDS querying function with filters, check the returned error,
diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go
index dde1a1a189940..59adf7f7f5b88 100644
--- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go
+++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go
@@ -21,14 +21,14 @@ package db
import (
"context"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
@@ -47,56 +47,57 @@ func (f *rdsDBProxyPlugin) ComponentShortName() string {
// GetDatabases returns a list of database resources representing RDS
// Proxies and custom endpoints.
func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) {
- rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region,
- cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
- cloud.WithCredentialsMaybeIntegration(cfg.Integration),
+ awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region,
+ awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID),
+ awsconfig.WithCredentialsMaybeIntegration(cfg.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
}
+ clt := cfg.awsClients.GetRDSClient(awsCfg)
// Get a list of all RDS Proxies. Each RDS Proxy has one "default"
// endpoint.
- rdsProxies, err := getRDSProxies(ctx, rdsClient, maxAWSPages)
+ rdsProxies, err := getRDSProxies(ctx, clt, maxAWSPages)
if err != nil {
return nil, trace.Wrap(err)
}
// Get all RDS Proxy custom endpoints sorted by the name of the RDS Proxy
// that owns the custom endpoints.
- customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, rdsClient, maxAWSPages)
+ customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, clt, maxAWSPages)
if err != nil {
cfg.Logger.DebugContext(ctx, "Failed to get RDS Proxy endpoints", "error", err)
}
var databases types.Databases
for _, dbProxy := range rdsProxies {
- if !aws.BoolValue(dbProxy.RequireTLS) {
- cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.StringValue(dbProxy.DBProxyName))
+ if !aws.ToBool(dbProxy.RequireTLS) {
+ cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.ToString(dbProxy.DBProxyName))
continue
}
- if !libcloudaws.IsRDSProxyAvailable(dbProxy) {
+ if !libcloudaws.IsRDSProxyAvailable(&dbProxy) {
cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS Proxy",
- "rds_proxy", aws.StringValue(dbProxy.DBProxyName),
- "status", aws.StringValue(dbProxy.Status))
+ "rds_proxy", aws.ToString(dbProxy.DBProxyName),
+ "status", dbProxy.Status)
continue
}
- // rds.DBProxy has no tags information. An extra SDK call is made to
+ // rdstypes.DBProxy has no tags information. An extra SDK call is made to
// fetch the tags. If failed, keep going without the tags.
- tags, err := listRDSResourceTags(ctx, rdsClient, dbProxy.DBProxyArn)
+ tags, err := listRDSResourceTags(ctx, clt, dbProxy.DBProxyArn)
if err != nil {
cfg.Logger.DebugContext(ctx, "Failed to get tags for RDS Proxy",
- "rds_proxy", aws.StringValue(dbProxy.DBProxyName),
+ "rds_proxy", aws.ToString(dbProxy.DBProxyName),
"error", err,
)
}
// Add a database from RDS Proxy (default endpoint).
- database, err := common.NewDatabaseFromRDSProxy(dbProxy, tags)
+ database, err := common.NewDatabaseFromRDSProxy(&dbProxy, tags)
if err != nil {
cfg.Logger.DebugContext(ctx, "Could not convert RDS Proxy to database resource",
- "rds_proxy", aws.StringValue(dbProxy.DBProxyName),
+ "rds_proxy", aws.ToString(dbProxy.DBProxyName),
"error", err,
)
} else {
@@ -104,21 +105,21 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf
}
// Add custom endpoints.
- for _, customEndpoint := range customEndpointsByProxyName[aws.StringValue(dbProxy.DBProxyName)] {
- if !libcloudaws.IsRDSProxyCustomEndpointAvailable(customEndpoint) {
+ for _, customEndpoint := range customEndpointsByProxyName[aws.ToString(dbProxy.DBProxyName)] {
+ if !libcloudaws.IsRDSProxyCustomEndpointAvailable(&customEndpoint) {
cfg.Logger.DebugContext(ctx, "Skipping unavailable custom endpoint of RDS Proxy",
- "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName),
- "rds_proxy", aws.StringValue(customEndpoint.DBProxyName),
- "status", aws.StringValue(customEndpoint.Status),
+ "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName),
+ "rds_proxy", aws.ToString(customEndpoint.DBProxyName),
+ "status", customEndpoint.Status,
)
continue
}
- database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, tags)
+ database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(&dbProxy, &customEndpoint, tags)
if err != nil {
cfg.Logger.DebugContext(ctx, "Could not convert custom endpoint for RDS Proxy to database resource",
- "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName),
- "rds_proxy", aws.StringValue(customEndpoint.DBProxyName),
+ "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName),
+ "rds_proxy", aws.ToString(customEndpoint.DBProxyName),
"error", err,
)
continue
@@ -132,46 +133,54 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf
// getRDSProxies fetches all RDS Proxies using the provided client, up to the
// specified max number of pages.
-func getRDSProxies(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (rdsProxies []*rds.DBProxy, err error) {
- var pageNum int
- err = rdsClient.DescribeDBProxiesPagesWithContext(
- ctx,
+func getRDSProxies(ctx context.Context, clt RDSClient, maxPages int) ([]rdstypes.DBProxy, error) {
+ pager := rds.NewDescribeDBProxiesPaginator(clt,
&rds.DescribeDBProxiesInput{},
- func(ddo *rds.DescribeDBProxiesOutput, lastPage bool) bool {
- pageNum++
- rdsProxies = append(rdsProxies, ddo.DBProxies...)
- return pageNum <= maxPages
+ func(dcpo *rds.DescribeDBProxiesPaginatorOptions) {
+ dcpo.StopOnDuplicateToken = true
},
)
- return rdsProxies, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+
+ var rdsProxies []rdstypes.DBProxy
+ for i := 0; i < maxPages && pager.HasMorePages(); i++ {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
+ }
+ rdsProxies = append(rdsProxies, page.DBProxies...)
+ }
+ return rdsProxies, nil
}
// getRDSProxyCustomEndpoints fetches all RDS Proxy custom endpoints using the
// provided client.
-func getRDSProxyCustomEndpoints(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (map[string][]*rds.DBProxyEndpoint, error) {
- customEndpointsByProxyName := make(map[string][]*rds.DBProxyEndpoint)
- var pageNum int
- err := rdsClient.DescribeDBProxyEndpointsPagesWithContext(
- ctx,
+func getRDSProxyCustomEndpoints(ctx context.Context, clt RDSClient, maxPages int) (map[string][]rdstypes.DBProxyEndpoint, error) {
+ customEndpointsByProxyName := make(map[string][]rdstypes.DBProxyEndpoint)
+ pager := rds.NewDescribeDBProxyEndpointsPaginator(clt,
&rds.DescribeDBProxyEndpointsInput{},
- func(ddo *rds.DescribeDBProxyEndpointsOutput, lastPage bool) bool {
- pageNum++
- for _, customEndpoint := range ddo.DBProxyEndpoints {
- customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)], customEndpoint)
- }
- return pageNum <= maxPages
+ func(ddepo *rds.DescribeDBProxyEndpointsPaginatorOptions) {
+ ddepo.StopOnDuplicateToken = true
},
)
- return customEndpointsByProxyName, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+ for i := 0; i < maxPages && pager.HasMorePages(); i++ {
+ page, err := pager.NextPage(ctx)
+ if err != nil {
+ return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
+ }
+ for _, customEndpoint := range page.DBProxyEndpoints {
+ customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)], customEndpoint)
+ }
+ }
+ return customEndpointsByProxyName, nil
}
// listRDSResourceTags returns tags for provided RDS resource.
-func listRDSResourceTags(ctx context.Context, rdsClient rdsiface.RDSAPI, resourceName *string) ([]*rds.Tag, error) {
- output, err := rdsClient.ListTagsForResourceWithContext(ctx, &rds.ListTagsForResourceInput{
+func listRDSResourceTags(ctx context.Context, clt RDSClient, resourceName *string) ([]rdstypes.Tag, error) {
+ output, err := clt.ListTagsForResource(ctx, &rds.ListTagsForResourceInput{
ResourceName: resourceName,
})
if err != nil {
- return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err))
+ return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err))
}
return output.TagList, nil
}
diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go
index b92ff2a439eda..99af538f590f4 100644
--- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go
+++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go
@@ -21,11 +21,10 @@ package db
import (
"testing"
- "github.com/aws/aws-sdk-go/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
@@ -41,10 +40,12 @@ func TestRDSDBProxyFetcher(t *testing.T) {
tests := []awsFetcherTest{
{
name: "fetch all",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2},
- DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2},
+ DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2},
+ },
},
},
inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", wildcardLabels),
@@ -52,10 +53,12 @@ func TestRDSDBProxyFetcher(t *testing.T) {
},
{
name: "fetch vpc1",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2},
- DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2},
+ DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2},
+ },
},
},
inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", map[string]string{"vpc-id": "vpc1"}),
@@ -65,7 +68,7 @@ func TestRDSDBProxyFetcher(t *testing.T) {
testAWSFetchers(t, tests...)
}
-func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) {
+func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rdstypes.DBProxy, types.Database) {
rdsProxy := mocks.RDSProxy(name, region, vpcID)
rdsProxyDatabase, err := common.NewDatabaseFromRDSProxy(rdsProxy, nil)
require.NoError(t, err)
@@ -73,7 +76,7 @@ func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types
return rdsProxy, rdsProxyDatabase
}
-func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rds.DBProxy, name, region string) (*rds.DBProxyEndpoint, types.Database) {
+func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rdstypes.DBProxy, name, region string) (*rdstypes.DBProxyEndpoint, types.Database) {
rdsProxyEndpoint := mocks.RDSProxyCustomEndpoint(rdsProxy, name, region)
rdsProxyEndpointDatabase, err := common.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint, nil)
require.NoError(t, err)
diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go
index 9dfc658268eeb..db4aeeb376cc3 100644
--- a/lib/srv/discovery/fetchers/db/aws_rds_test.go
+++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go
@@ -21,13 +21,13 @@ package db
import (
"testing"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/rds"
- "github.com/aws/aws-sdk-go/service/rds/rdsiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
+ rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
+ "github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/discovery/common"
@@ -38,8 +38,8 @@ import (
func TestRDSFetchers(t *testing.T) {
t.Parallel()
- auroraMySQLEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)}
- postgresEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)}
+ auroraMySQLEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)}
+ postgresEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)}
rdsInstance1, rdsDatabase1 := makeRDSInstance(t, "instance-1", "us-east-1", envProdLabels)
rdsInstance2, rdsDatabase2 := makeRDSInstance(t, "instance-2", "us-east-2", envProdLabels)
@@ -58,19 +58,19 @@ func TestRDSFetchers(t *testing.T) {
tests := []awsFetcherTest{
{
name: "fetch all",
- inputClients: &cloud.TestCloudClients{
- RDSPerRegion: map[string]rdsiface.RDSAPI{
- "us-east-1": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster1},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{
+ "us-east-1": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster1},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine},
},
- "us-east-2": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine},
+ "us-east-2": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine},
},
- },
+ }),
},
inputMatchers: []types.AWSMatcher{
{
@@ -91,19 +91,19 @@ func TestRDSFetchers(t *testing.T) {
},
{
name: "fetch different labels for different regions",
- inputClients: &cloud.TestCloudClients{
- RDSPerRegion: map[string]rdsiface.RDSAPI{
- "us-east-1": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster1},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{
+ "us-east-1": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster1},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine},
},
- "us-east-2": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine},
+ "us-east-2": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine},
},
- },
+ }),
},
inputMatchers: []types.AWSMatcher{
{
@@ -124,19 +124,19 @@ func TestRDSFetchers(t *testing.T) {
},
{
name: "skip unrecognized engines",
- inputClients: &cloud.TestCloudClients{
- RDSPerRegion: map[string]rdsiface.RDSAPI{
- "us-east-1": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster1},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{
+ "us-east-1": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster1},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine},
},
- "us-east-2": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3},
- DBEngineVersions: []*rds.DBEngineVersion{postgresEngine},
+ "us-east-2": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*postgresEngine},
},
- },
+ }),
},
inputMatchers: []types.AWSMatcher{
{
@@ -154,14 +154,14 @@ func TestRDSFetchers(t *testing.T) {
},
{
name: "skip unsupported databases",
- inputClients: &cloud.TestCloudClients{
- RDSPerRegion: map[string]rdsiface.RDSAPI{
- "us-east-1": &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{auroraCluster1MemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnsupported},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{
+ "us-east-1": &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*auroraCluster1MemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnsupported},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine},
},
- },
+ }),
},
inputMatchers: []types.AWSMatcher{{
Types: []string{types.AWSMatcherRDS},
@@ -172,11 +172,13 @@ func TestRDSFetchers(t *testing.T) {
},
{
name: "skip unavailable databases",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstanceUnavailable, rdsInstanceUnknownStatus, auroraCluster1MemberInstance, auroraClusterUnknownStatusMemberInstance},
- DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnavailable, auroraClusterUnknownStatus},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstanceUnavailable, *rdsInstanceUnknownStatus, *auroraCluster1MemberInstance, *auroraClusterUnknownStatusMemberInstance},
+ DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnavailable, *auroraClusterUnknownStatus},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{{
@@ -188,11 +190,13 @@ func TestRDSFetchers(t *testing.T) {
},
{
name: "Aurora cluster without writer",
- inputClients: &cloud.TestCloudClients{
- RDS: &mocks.RDSMock{
- DBClusters: []*rds.DBCluster{auroraClusterNoWriter},
- DBInstances: []*rds.DBInstance{auroraClusterMemberNoWriter},
- DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine},
+ fetcherCfg: AWSFetcherFactoryConfig{
+ AWSClients: fakeAWSClients{
+ rdsClient: &mocks.RDSClient{
+ DBClusters: []rdstypes.DBCluster{*auroraClusterNoWriter},
+ DBInstances: []rdstypes.DBInstance{*auroraClusterMemberNoWriter},
+ DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine},
+ },
},
},
inputMatchers: []types.AWSMatcher{{
@@ -206,7 +210,7 @@ func TestRDSFetchers(t *testing.T) {
testAWSFetchers(t, tests...)
}
-func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) {
+func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) (*rdstypes.DBInstance, types.Database) {
instance := mocks.RDSInstance(name, region, labels, opts...)
database, err := common.NewDatabaseFromRDSInstance(instance)
require.NoError(t, err)
@@ -214,21 +218,21 @@ func makeRDSInstance(t *testing.T, name, region string, labels map[string]string
return instance, database
}
-func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) (*rds.DBCluster, *rds.DBInstance, types.Database) {
+func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Database) {
cluster := mocks.RDSCluster(name, region, labels, opts...)
dbInstanceMember := makeRDSMemberForCluster(t, name, region, "vpc-123", *cluster.Engine, labels)
- database, err := common.NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{dbInstanceMember})
+ database, err := common.NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{*dbInstanceMember})
require.NoError(t, err)
common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRDS)
return cluster, dbInstanceMember, database
}
-func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rds.DBInstance {
- instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rds.DBInstance) {
+func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rdstypes.DBInstance {
+ instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rdstypes.DBInstance) {
if d.DBSubnetGroup == nil {
- d.DBSubnetGroup = &rds.DBSubnetGroup{}
+ d.DBSubnetGroup = &rdstypes.DBSubnetGroup{}
}
- d.DBSubnetGroup.SetVpcId(vpcid)
+ d.DBSubnetGroup.VpcId = aws.String(vpcid)
d.DBClusterIdentifier = aws.String(name)
d.Engine = aws.String(engine)
})
@@ -236,9 +240,9 @@ func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, l
return instanceRDSMember
}
-func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rds.DBCluster, *rds.DBInstance, types.Databases) {
+func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Databases) {
cluster := mocks.RDSCluster(name, region, labels,
- func(cluster *rds.DBCluster) {
+ func(cluster *rdstypes.DBCluster) {
// Disable writer by default. If hasWriter, writer endpoint will be added below.
cluster.DBClusterMembers = nil
},
@@ -249,11 +253,11 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels
var databases types.Databases
- instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.StringValue(cluster.Engine), labels)
- dbInstanceMembers := []*rds.DBInstance{instanceRDSMember}
+ instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.ToString(cluster.Engine), labels)
+ dbInstanceMembers := []rdstypes.DBInstance{*instanceRDSMember}
if hasWriter {
- cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{
+ cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{
IsClusterWriter: aws.Bool(true), // Add writer.
})
@@ -277,22 +281,49 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels
}
// withRDSInstanceStatus returns an option function for makeRDSInstance to overwrite status.
-func withRDSInstanceStatus(status string) func(*rds.DBInstance) {
- return func(instance *rds.DBInstance) {
+func withRDSInstanceStatus(status string) func(*rdstypes.DBInstance) {
+ return func(instance *rdstypes.DBInstance) {
instance.DBInstanceStatus = aws.String(status)
}
}
// withRDSClusterEngineMode returns an option function for makeRDSCluster to overwrite engine mode.
-func withRDSClusterEngineMode(mode string) func(*rds.DBCluster) {
- return func(cluster *rds.DBCluster) {
+func withRDSClusterEngineMode(mode string) func(*rdstypes.DBCluster) {
+ return func(cluster *rdstypes.DBCluster) {
cluster.EngineMode = aws.String(mode)
}
}
// withRDSClusterStatus returns an option function for makeRDSCluster to overwrite status.
-func withRDSClusterStatus(status string) func(*rds.DBCluster) {
- return func(cluster *rds.DBCluster) {
+func withRDSClusterStatus(status string) func(*rdstypes.DBCluster) {
+ return func(cluster *rdstypes.DBCluster) {
cluster.Status = aws.String(status)
}
}
+
+// newRegionalFakeRDSClientProvider provides a client specific to each region, where the map keys are regions.
+func newRegionalFakeRDSClientProvider(cs map[string]RDSClient) fakeRegionalRDSClients {
+ return fakeRegionalRDSClients{rdsClients: cs}
+}
+
+type fakeAWSClients struct {
+ rdsClient RDSClient
+ redshiftClient RedshiftClient
+}
+
+func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient {
+ return f.rdsClient
+}
+
+func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient {
+ return f.redshiftClient
+}
+
+type fakeRegionalRDSClients struct {
+ AWSClientProvider
+ rdsClients map[string]RDSClient
+}
+
+func (f fakeRegionalRDSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient {
+ return f.rdsClients[cfg.Region]
+}
diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go
index 0cda0b478e67b..b6a17f32ede5e 100644
--- a/lib/srv/discovery/fetchers/db/aws_redshift.go
+++ b/lib/srv/discovery/fetchers/db/aws_redshift.go
@@ -32,9 +32,6 @@ import (
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
-// RedshiftClientProviderFunc provides a [RedshiftClient].
-type RedshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient
-
// RedshiftClient is a subset of the AWS Redshift API.
type RedshiftClient interface {
redshift.DescribeClustersAPIClient
@@ -57,7 +54,7 @@ func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig
if err != nil {
return nil, trace.Wrap(err)
}
- clusters, err := getRedshiftClusters(ctx, cfg.redshiftClientProviderFn(awsCfg))
+ clusters, err := getRedshiftClusters(ctx, cfg.awsClients.GetRedshiftClient(awsCfg))
if err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go
index ded47035e96e3..8e95641f7931c 100644
--- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go
+++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go
@@ -22,7 +22,6 @@ import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/redshift"
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/stretchr/testify/require"
@@ -31,11 +30,6 @@ import (
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
-func newFakeRedshiftClientProvider(c RedshiftClient) RedshiftClientProviderFunc {
- return func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient {
- return c
- }
-}
func TestRedshiftFetcher(t *testing.T) {
t.Parallel()
@@ -48,9 +42,11 @@ func TestRedshiftFetcher(t *testing.T) {
{
name: "fetch all",
fetcherCfg: AWSFetcherFactoryConfig{
- RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{
- Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev},
- }),
+ AWSClients: fakeAWSClients{
+ redshiftClient: &mocks.RedshiftClient{
+ Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev},
+ },
+ },
},
inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUse1Dev},
@@ -58,9 +54,11 @@ func TestRedshiftFetcher(t *testing.T) {
{
name: "fetch prod",
fetcherCfg: AWSFetcherFactoryConfig{
- RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{
- Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev},
- }),
+ AWSClients: fakeAWSClients{
+ redshiftClient: &mocks.RedshiftClient{
+ Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev},
+ },
+ },
},
inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", envProdLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod},
@@ -68,9 +66,11 @@ func TestRedshiftFetcher(t *testing.T) {
{
name: "skip unavailable",
fetcherCfg: AWSFetcherFactoryConfig{
- RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{
- Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Unavailable, *redshiftUse1UnknownStatus},
- }),
+ AWSClients: fakeAWSClients{
+ redshiftClient: &mocks.RedshiftClient{
+ Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Unavailable, *redshiftUse1UnknownStatus},
+ },
+ },
},
inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUnknownStatus},
diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go
index 8d79bc2bb65bc..cd4df7269a14e 100644
--- a/lib/srv/discovery/fetchers/db/db.go
+++ b/lib/srv/discovery/fetchers/db/db.go
@@ -23,6 +23,7 @@ import (
"log/slog"
"github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/rds"
"github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/gravitational/trace"
"golang.org/x/exp/maps"
@@ -67,14 +68,32 @@ func IsAzureMatcherType(matcherType string) bool {
return len(makeAzureFetcherFuncs[matcherType]) > 0
}
+// AWSClientProvider provides AWS service API clients.
+type AWSClientProvider interface {
+ // GetRDSClient provides an [RDSClient].
+ GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient
+	// GetRedshiftClient provides a [RedshiftClient].
+ GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient
+}
+
+type defaultAWSClients struct{}
+
+func (defaultAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient {
+ return rds.NewFromConfig(cfg, optFns...)
+}
+
+func (defaultAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient {
+ return redshift.NewFromConfig(cfg, optFns...)
+}
+
// AWSFetcherFactoryConfig is the configuration for an [AWSFetcherFactory].
type AWSFetcherFactoryConfig struct {
// AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
AWSConfigProvider awsconfig.Provider
+ // AWSClients provides AWS SDK clients.
+ AWSClients AWSClientProvider
// CloudClients is an interface for retrieving AWS SDK v1 cloud clients.
CloudClients cloud.AWSClients
- // RedshiftClientProviderFn is an optional function that provides
- RedshiftClientProviderFn RedshiftClientProviderFunc
}
func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error {
@@ -84,10 +103,8 @@ func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error {
if c.AWSConfigProvider == nil {
return trace.BadParameter("missing AWSConfigProvider")
}
- if c.RedshiftClientProviderFn == nil {
- c.RedshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient {
- return redshift.NewFromConfig(cfg, optFns...)
- }
+ if c.AWSClients == nil {
+ c.AWSClients = defaultAWSClients{}
}
return nil
}
@@ -125,15 +142,15 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A
for _, makeFetcher := range makeFetchers {
for _, region := range matcher.Regions {
fetcher, err := makeFetcher(awsFetcherConfig{
- AWSClients: f.cfg.CloudClients,
- Type: matcherType,
- AssumeRole: assumeRole,
- Labels: matcher.Tags,
- Region: region,
- Integration: matcher.Integration,
- DiscoveryConfigName: discoveryConfigName,
- AWSConfigProvider: f.cfg.AWSConfigProvider,
- redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn,
+ AWSClients: f.cfg.CloudClients,
+ Type: matcherType,
+ AssumeRole: assumeRole,
+ Labels: matcher.Tags,
+ Region: region,
+ Integration: matcher.Integration,
+ DiscoveryConfigName: discoveryConfigName,
+ AWSConfigProvider: f.cfg.AWSConfigProvider,
+ awsClients: f.cfg.AWSClients,
})
if err != nil {
return nil, trace.Wrap(err)
diff --git a/lib/srv/discovery/fetchers/eks.go b/lib/srv/discovery/fetchers/eks.go
index 193244bba75e3..27dcbdd2d83fd 100644
--- a/lib/srv/discovery/fetchers/eks.go
+++ b/lib/srv/discovery/fetchers/eks.go
@@ -29,13 +29,12 @@ import (
"sync"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/arn"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
- "github.com/aws/aws-sdk-go/service/iam"
- "github.com/aws/aws-sdk-go/service/sts"
- "github.com/aws/aws-sdk-go/service/sts/stsiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/arn"
+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"golang.org/x/sync/errgroup"
@@ -48,8 +47,8 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/fixtures"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/services"
@@ -63,24 +62,48 @@ const (
type eksFetcher struct {
EKSFetcherConfig
- mu sync.Mutex
- client eksiface.EKSAPI
- stsClient stsiface.STSAPI
- callerIdentity string
+ mu sync.Mutex
+ client EKSClient
+ stsPresignClient STSPresignClient
+ callerIdentity string
}
-// ClientGetter is an interface for getting an EKS client and an STS client.
-type ClientGetter interface {
- // GetAWSEKSClient returns AWS EKS client for the specified region.
- GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (eksiface.EKSAPI, error)
- // GetAWSSTSClient returns AWS STS client for the specified region.
- GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error)
+// EKSClient is the subset of the EKS interface we use in fetchers.
+type EKSClient interface {
+ eks.DescribeClusterAPIClient
+ eks.ListClustersAPIClient
+
+ AssociateAccessPolicy(ctx context.Context, params *eks.AssociateAccessPolicyInput, optFns ...func(*eks.Options)) (*eks.AssociateAccessPolicyOutput, error)
+ CreateAccessEntry(ctx context.Context, params *eks.CreateAccessEntryInput, optFns ...func(*eks.Options)) (*eks.CreateAccessEntryOutput, error)
+ DeleteAccessEntry(ctx context.Context, params *eks.DeleteAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DeleteAccessEntryOutput, error)
+ DescribeAccessEntry(ctx context.Context, params *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error)
+ UpdateAccessEntry(ctx context.Context, params *eks.UpdateAccessEntryInput, optFns ...func(*eks.Options)) (*eks.UpdateAccessEntryOutput, error)
+}
+
+// STSClient is the subset of the STS interface we use in fetchers.
+type STSClient interface {
+ GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
+ stscreds.AssumeRoleAPIClient
+}
+
+// STSPresignClient is the subset of the STS presign interface we use in fetchers.
+type STSPresignClient = kubeutils.STSPresignClient
+
+// AWSClientGetter is an interface for getting an EKS client and an STS client.
+type AWSClientGetter interface {
+ awsconfig.Provider
+ // GetAWSEKSClient returns AWS EKS client for the specified config.
+ GetAWSEKSClient(aws.Config) EKSClient
+ // GetAWSSTSClient returns AWS STS client for the specified config.
+ GetAWSSTSClient(aws.Config) STSClient
+ // GetAWSSTSPresignClient returns AWS STS presign client for the specified config.
+ GetAWSSTSPresignClient(aws.Config) STSPresignClient
}
// EKSFetcherConfig configures the EKS fetcher.
type EKSFetcherConfig struct {
// ClientGetter retrieves an EKS client and an STS client.
- ClientGetter ClientGetter
+ ClientGetter AWSClientGetter
// AssumeRole provides a role ARN and ExternalID to assume an AWS role
// when fetching clusters.
AssumeRole types.AssumeRole
@@ -133,7 +156,7 @@ func (c *EKSFetcherConfig) CheckAndSetDefaults() error {
// MakeEKSFetchersFromAWSMatchers creates fetchers from the provided matchers. Returned fetchers are separated
// by their reliance on the integration.
-func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients cloud.AWSClients, matchers []types.AWSMatcher, discoveryConfigName string) (kubeFetchers []common.Fetcher, _ error) {
+func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients AWSClientGetter, matchers []types.AWSMatcher, discoveryConfigName string) (kubeFetchers []common.Fetcher, _ error) {
for _, matcher := range matchers {
var matcherAssumeRole types.AssumeRole
if matcher.AssumeRole != nil {
@@ -162,7 +185,8 @@ func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients cloud.AWSClient
"error", err,
"region", region,
"labels", matcher.Tags,
- "assume_role", matcherAssumeRole.RoleARN)
+ "assume_role", matcherAssumeRole.RoleARN,
+ )
continue
}
kubeFetchers = append(kubeFetchers, fetcher)
@@ -197,7 +221,7 @@ func NewEKSFetcher(cfg EKSFetcherConfig) (common.Fetcher, error) {
return fetcher, nil
}
-func (a *eksFetcher) getClient(ctx context.Context) (eksiface.EKSAPI, error) {
+func (a *eksFetcher) getClient(ctx context.Context) (EKSClient, error) {
a.mu.Lock()
defer a.mu.Unlock()
@@ -205,16 +229,12 @@ func (a *eksFetcher) getClient(ctx context.Context) (eksiface.EKSAPI, error) {
return a.client, nil
}
- client, err := a.ClientGetter.GetAWSEKSClient(
- ctx,
- a.Region,
- a.getAWSOpts()...,
- )
+ cfg, err := a.ClientGetter.GetConfig(ctx, a.Region, a.getAWSOpts()...)
if err != nil {
return nil, trace.Wrap(err)
}
- a.client = client
+ a.client = a.ClientGetter.GetAWSEKSClient(cfg)
return a.client, nil
}
@@ -280,39 +300,38 @@ func (a *eksFetcher) getEKSClusters(ctx context.Context) (types.KubeClusters, er
return nil, trace.Wrap(err, "failed getting AWS EKS client")
}
- err = client.ListClustersPagesWithContext(ctx,
- &eks.ListClustersInput{
- Include: nil, // For now we should only list EKS clusters
- },
- func(clustersList *eks.ListClustersOutput, _ bool) bool {
- for i := 0; i < len(clustersList.Clusters); i++ {
- clusterName := aws.StringValue(clustersList.Clusters[i])
- // group.Go will block if the concurrency limit is reached.
- // It will resume once any running function finishes.
- group.Go(func() error {
- cluster, err := a.getMatchingKubeCluster(groupCtx, clusterName)
- // trace.CompareFailed is returned if the cluster did not match the matcher filtering labels
- // or if the cluster is not yet active.
- if trace.IsCompareFailed(err) {
- a.Logger.DebugContext(groupCtx, "Cluster did not match the filtering criteria", "error", err, "cluster", clusterName)
- // never return an error otherwise we will impact discovery process
- return nil
- } else if err != nil {
- a.Logger.WarnContext(groupCtx, "Failed to discover EKS cluster", "error", err, "cluster", clusterName)
- // never return an error otherwise we will impact discovery process
- return nil
- }
-
- mu.Lock()
- defer mu.Unlock()
- clusters = append(clusters, cluster)
+ // For now we should only list EKS clusters so we use nil (default) input param.
+ for p := eks.NewListClustersPaginator(client, nil); p.HasMorePages(); {
+ out, err := p.NextPage(ctx)
+ if err != nil {
+ return clusters, trace.Wrap(err)
+ }
+ for _, clusterName := range out.Clusters {
+ // group.Go will block if the concurrency limit is reached.
+ // It will resume once any running function finishes.
+ group.Go(func() error {
+ cluster, err := a.getMatchingKubeCluster(groupCtx, clusterName)
+ // trace.CompareFailed is returned if the cluster did not match the matcher filtering labels
+ // or if the cluster is not yet active.
+ if trace.IsCompareFailed(err) {
+ a.Logger.DebugContext(groupCtx, "Cluster did not match the filtering criteria", "error", err, "cluster", clusterName)
+ // never return an error otherwise we will impact discovery process
return nil
- })
- }
- return true
- },
- )
- // error can be discarded since we do not return any error from group.Go closure.
+ } else if err != nil {
+ a.Logger.WarnContext(groupCtx, "Failed to discover EKS cluster", "error", err, "cluster", clusterName)
+ // never return an error otherwise we will impact discovery process
+ return nil
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ clusters = append(clusters, cluster)
+ return nil
+ })
+ }
+ }
+
+ // The error can be discarded since we do not return any error from group.Go closure.
_ = group.Wait()
return clusters, trace.Wrap(err)
}
@@ -352,7 +371,7 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str
return nil, trace.Wrap(err, "failed getting AWS EKS client")
}
- rsp, err := client.DescribeClusterWithContext(
+ rsp, err := client.DescribeCluster(
ctx,
&eks.DescribeClusterInput{
Name: aws.String(clusterName),
@@ -362,14 +381,14 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str
return nil, trace.WrapWithMessage(err, "Unable to describe EKS cluster %q", clusterName)
}
- switch st := aws.StringValue(rsp.Cluster.Status); st {
- case eks.ClusterStatusUpdating, eks.ClusterStatusActive:
+ switch st := rsp.Cluster.Status; st {
+ case ekstypes.ClusterStatusUpdating, ekstypes.ClusterStatusActive:
a.Logger.DebugContext(ctx, "EKS cluster status is valid", "status", st, "cluster", clusterName)
default:
return nil, trace.CompareFailed("EKS cluster %q not enrolled due to its current status: %s", clusterName, st)
}
- cluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(rsp.Cluster.Name), aws.StringValue(rsp.Cluster.Arn), rsp.Cluster.Tags)
+ cluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(rsp.Cluster.Name), aws.ToString(rsp.Cluster.Arn), rsp.Cluster.Tags)
if err != nil {
return nil, trace.WrapWithMessage(err, "Unable to convert eks.Cluster cluster into types.KubernetesClusterV3.")
}
@@ -388,8 +407,8 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str
// If the fetcher should setup access for the specified ARN, first check if the cluster authentication mode
// is set to either [eks.AuthenticationModeApi] or [eks.AuthenticationModeApiAndConfigMap].
// If the authentication mode is set to [eks.AuthenticationModeConfigMap], the fetcher will ignore the cluster.
- switch st := aws.StringValue(rsp.Cluster.AccessConfig.AuthenticationMode); st {
- case eks.AuthenticationModeApiAndConfigMap, eks.AuthenticationModeApi:
+ switch st := rsp.Cluster.AccessConfig.AuthenticationMode; st {
+ case ekstypes.AuthenticationModeApiAndConfigMap, ekstypes.AuthenticationModeApi:
if err := a.checkOrSetupAccessForARN(ctx, client, rsp.Cluster); err != nil {
return nil, trace.Wrap(err, "unable to setup access for EKS cluster %q", clusterName)
}
@@ -427,9 +446,9 @@ var eksDiscoveryPermissions = []string{
// The check involves checking if the access entry exists and if the "teleport:kube-agent:eks" is part of the Kubernetes group.
// If the access entry doesn't exist or is misconfigured, the fetcher will temporarily gain admin access and create the role and binding.
// The fetcher will then upsert the access entry with the correct Kubernetes group.
-func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error {
+func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error {
entry, err := convertAWSError(
- client.DescribeAccessEntryWithContext(ctx,
+ client.DescribeAccessEntry(ctx,
&eks.DescribeAccessEntryInput{
ClusterName: cluster.Name,
PrincipalArn: aws.String(a.SetupAccessForARN),
@@ -442,13 +461,13 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa
// Access denied means that the principal does not have access to setup access entries for the cluster.
a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set",
"error", err,
- "cluster", aws.StringValue(cluster.Name),
+ "cluster", aws.ToString(cluster.Name),
"required_permissions", eksDiscoveryPermissions,
)
return nil
case err == nil:
// If the access entry exists and the principal has access to the cluster, check if the teleportKubernetesGroup is part of the Kubernetes group.
- if entry.AccessEntry != nil && slices.Contains(aws.StringValueSlice(entry.AccessEntry.KubernetesGroups), teleportKubernetesGroup) {
+ if entry.AccessEntry != nil && slices.Contains(entry.AccessEntry.KubernetesGroups, teleportKubernetesGroup) {
return nil
}
fallthrough
@@ -459,12 +478,12 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa
// Access denied means that the principal does not have access to setup access entries for the cluster.
a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set",
"error", err,
- "cluster", aws.StringValue(cluster.Name),
+ "cluster", aws.ToString(cluster.Name),
"required_permissions", eksDiscoveryPermissions,
)
return nil
} else if err != nil {
- return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.StringValue(cluster.Name))
+ return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.ToString(cluster.Name))
}
// upsert the access entry with the correct Kubernetes group for the final
@@ -473,29 +492,29 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa
// Access denied means that the principal does not have access to setup access entries for the cluster.
a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set",
"error", err,
- "cluster", aws.StringValue(cluster.Name),
+ "cluster", aws.ToString(cluster.Name),
"required_permissions", eksDiscoveryPermissions,
)
return nil
}
- return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.StringValue(cluster.Name))
+ return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.ToString(cluster.Name))
default:
return trace.Wrap(err)
}
-
}
// temporarilyGainAdminAccessAndCreateRole temporarily gains admin access to the EKS cluster by associating the EKS Cluster Admin Policy
// to the callerIdentity. The fetcher will then create the role and binding for the teleportKubernetesGroup in the EKS cluster.
-func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error {
+func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error {
const (
// https://docs.aws.amazon.com/eks/latest/userguide/access-policies.html
// We use cluster admin policy to create namespace and cluster role.
eksClusterAdminPolicy = "arn:aws:eks::aws:cluster-access-policy/AmazonEKSClusterAdminPolicy"
)
+
// Setup access for the ARN
rsp, err := convertAWSError(
- client.CreateAccessEntryWithContext(ctx,
+ client.CreateAccessEntry(ctx,
&eks.CreateAccessEntryInput{
ClusterName: cluster.Name,
PrincipalArn: aws.String(a.callerIdentity),
@@ -510,7 +529,7 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context
if rsp != nil {
defer func() {
_, err := convertAWSError(
- client.DeleteAccessEntryWithContext(
+ client.DeleteAccessEntry(
ctx,
&eks.DeleteAccessEntryInput{
ClusterName: cluster.Name,
@@ -520,18 +539,17 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context
if err != nil {
a.Logger.WarnContext(ctx, "Failed to delete access entry for EKS cluster",
"error", err,
- "cluster", aws.StringValue(cluster.Name),
+ "cluster", aws.ToString(cluster.Name),
)
}
}()
-
}
_, err = convertAWSError(
- client.AssociateAccessPolicyWithContext(ctx, &eks.AssociateAccessPolicyInput{
- AccessScope: &eks.AccessScope{
+ client.AssociateAccessPolicy(ctx, &eks.AssociateAccessPolicyInput{
+ AccessScope: &ekstypes.AccessScope{
Namespaces: nil,
- Type: aws.String(eks.AccessScopeTypeCluster),
+ Type: ekstypes.AccessScopeTypeCluster,
},
ClusterName: cluster.Name,
PolicyArn: aws.String(eksClusterAdminPolicy),
@@ -539,7 +557,7 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context
}),
)
if err != nil && !trace.IsAlreadyExists(err) {
- return trace.Wrap(err, "unable to associate EKS Access Policy to cluster %q", aws.StringValue(cluster.Name))
+ return trace.Wrap(err, "unable to associate EKS Access Policy to cluster %q", aws.ToString(cluster.Name))
}
timeout := a.Clock.NewTimer(60 * time.Second)
@@ -561,17 +579,19 @@ forLoop:
}
}
- return trace.Wrap(err, "unable to upsert role and binding for cluster %q", aws.StringValue(cluster.Name))
+ return trace.Wrap(err, "unable to upsert role and binding for cluster %q", aws.ToString(cluster.Name))
}
// upsertRoleAndBinding upserts the ClusterRole and ClusterRoleBinding for the teleportKubernetesGroup in the EKS cluster.
-func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *eks.Cluster) error {
- client, err := a.createKubeClient(cluster)
+func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *ekstypes.Cluster) error {
+ client, err := a.createKubeClient(ctx, cluster)
if err != nil {
- return trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.StringValue(cluster.Name))
+ return trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.ToString(cluster.Name))
}
+
ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
+
if err := a.upsertClusterRoleWithAdminCredentials(ctx, client); err != nil {
return trace.Wrap(err, "unable to upsert ClusterRole for group %q", teleportKubernetesGroup)
}
@@ -583,23 +603,23 @@ func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *eks.Clus
return nil
}
-func (a *eksFetcher) createKubeClient(cluster *eks.Cluster) (*kubernetes.Clientset, error) {
- if a.stsClient == nil {
- return nil, trace.BadParameter("STS client is not set")
+func (a *eksFetcher) createKubeClient(ctx context.Context, cluster *ekstypes.Cluster) (*kubernetes.Clientset, error) {
+ if a.stsPresignClient == nil {
+ return nil, trace.BadParameter("STS presign client is not set")
}
- token, _, err := kubeutils.GenAWSEKSToken(a.stsClient, aws.StringValue(cluster.Name), a.Clock)
+ token, _, err := kubeutils.GenAWSEKSToken(ctx, a.stsPresignClient, aws.ToString(cluster.Name), a.Clock)
if err != nil {
- return nil, trace.Wrap(err, "unable to generate EKS token for cluster %q", aws.StringValue(cluster.Name))
+ return nil, trace.Wrap(err, "unable to generate EKS token for cluster %q", aws.ToString(cluster.Name))
}
- ca, err := base64.StdEncoding.DecodeString(aws.StringValue(cluster.CertificateAuthority.Data))
+ ca, err := base64.StdEncoding.DecodeString(aws.ToString(cluster.CertificateAuthority.Data))
if err != nil {
- return nil, trace.Wrap(err, "unable to decode EKS cluster %q certificate authority", aws.StringValue(cluster.Name))
+ return nil, trace.Wrap(err, "unable to decode EKS cluster %q certificate authority", aws.ToString(cluster.Name))
}
- apiEndpoint := aws.StringValue(cluster.Endpoint)
+ apiEndpoint := aws.ToString(cluster.Endpoint)
if len(apiEndpoint) == 0 {
- return nil, trace.BadParameter("invalid api endpoint for cluster %q", aws.StringValue(cluster.Name))
+ return nil, trace.BadParameter("invalid api endpoint for cluster %q", aws.ToString(cluster.Name))
}
client, err := kubernetes.NewForConfig(
@@ -611,7 +631,7 @@ func (a *eksFetcher) createKubeClient(cluster *eks.Cluster) (*kubernetes.Clients
},
},
)
- return client, trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.StringValue(cluster.Name))
+ return client, trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.ToString(cluster.Name))
}
// upsertClusterRoleWithAdminCredentials tries to upsert the ClusterRole using admin credentials.
@@ -664,13 +684,13 @@ func (a *eksFetcher) upsertClusterRoleBindingWithAdminCredentials(ctx context.Co
}
// upsertAccessEntry upserts the access entry for the specified ARN with the teleportKubernetesGroup.
-func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error {
+func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error {
_, err := convertAWSError(
- client.CreateAccessEntryWithContext(ctx,
+ client.CreateAccessEntry(ctx,
&eks.CreateAccessEntryInput{
ClusterName: cluster.Name,
PrincipalArn: aws.String(a.SetupAccessForARN),
- KubernetesGroups: aws.StringSlice([]string{teleportKubernetesGroup}),
+ KubernetesGroups: []string{teleportKubernetesGroup},
},
))
if err == nil || !trace.IsAlreadyExists(err) {
@@ -678,11 +698,11 @@ func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSA
}
_, err = convertAWSError(
- client.UpdateAccessEntryWithContext(ctx,
+ client.UpdateAccessEntry(ctx,
&eks.UpdateAccessEntryInput{
ClusterName: cluster.Name,
PrincipalArn: aws.String(a.SetupAccessForARN),
- KubernetesGroups: aws.StringSlice([]string{teleportKubernetesGroup}),
+ KubernetesGroups: []string{teleportKubernetesGroup},
},
))
@@ -690,35 +710,35 @@ func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSA
}
func (a *eksFetcher) setCallerIdentity(ctx context.Context) error {
- var err error
- a.stsClient, err = a.ClientGetter.GetAWSSTSClient(
- ctx,
+ cfg, err := a.ClientGetter.GetConfig(ctx,
a.Region,
a.getAWSOpts()...,
)
if err != nil {
return trace.Wrap(err)
}
-
+ a.stsPresignClient = a.ClientGetter.GetAWSSTSPresignClient(cfg)
if a.AssumeRole.RoleARN != "" {
a.callerIdentity = a.AssumeRole.RoleARN
return nil
}
- identity, err := a.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
+
+ stsClient := a.ClientGetter.GetAWSSTSClient(cfg)
+ identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return trace.Wrap(err)
}
- a.callerIdentity = convertAssumedRoleToIAMRole(aws.StringValue(identity.Arn))
+ a.callerIdentity = convertAssumedRoleToIAMRole(aws.ToString(identity.Arn))
return nil
}
-func (a *eksFetcher) getAWSOpts() []cloud.AWSOptionsFn {
- return []cloud.AWSOptionsFn{
- cloud.WithAssumeRole(
+func (a *eksFetcher) getAWSOpts() []awsconfig.OptionsFn {
+ return []awsconfig.OptionsFn{
+ awsconfig.WithAssumeRole(
a.AssumeRole.RoleARN,
a.AssumeRole.ExternalID,
),
- cloud.WithCredentialsMaybeIntegration(a.Integration),
+ awsconfig.WithCredentialsMaybeIntegration(a.Integration),
}
}
@@ -734,6 +754,7 @@ func convertAssumedRoleToIAMRole(callerIdentity string) string {
const (
assumeRolePrefix = "assumed-role/"
roleResource = "role"
+ serviceName = "iam"
)
a, err := arn.Parse(callerIdentity)
if err != nil {
@@ -742,7 +763,7 @@ func convertAssumedRoleToIAMRole(callerIdentity string) string {
if !strings.HasPrefix(a.Resource, assumeRolePrefix) {
return callerIdentity
}
- a.Service = iam.ServiceName
+ a.Service = serviceName
split := strings.Split(a.Resource, "/")
if len(split) <= 2 {
return callerIdentity
diff --git a/lib/srv/discovery/fetchers/eks_test.go b/lib/srv/discovery/fetchers/eks_test.go
index d7b9c6b4cac47..ad8c8667d2862 100644
--- a/lib/srv/discovery/fetchers/eks_test.go
+++ b/lib/srv/discovery/fetchers/eks_test.go
@@ -23,16 +23,16 @@ import (
"errors"
"testing"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/eks"
- "github.com/aws/aws-sdk-go/service/eks/eksiface"
- "github.com/aws/aws-sdk-go/service/sts"
- "github.com/aws/aws-sdk-go/service/sts/stsiface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/service/eks"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/mocks"
+ kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/srv/discovery/common"
"github.com/gravitational/teleport/lib/utils"
)
@@ -43,9 +43,10 @@ func TestEKSFetcher(t *testing.T) {
filterLabels types.Labels
}
tests := []struct {
- name string
- args args
- want types.ResourcesWithLabels
+ name string
+ args args
+ assumeRole types.AssumeRole
+ want types.ResourcesWithLabels
}{
{
name: "list everything",
@@ -57,6 +58,17 @@ func TestEKSFetcher(t *testing.T) {
},
want: eksClustersToResources(t, eksMockClusters...),
},
+ {
+ name: "list everything with assumed role",
+ args: args{
+ region: types.Wildcard,
+ filterLabels: types.Labels{
+ types.Wildcard: []string{types.Wildcard},
+ },
+ },
+ assumeRole: types.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "extID123"},
+ want: eksClustersToResources(t, eksMockClusters...),
+ },
{
name: "list prod clusters",
args: args{
@@ -88,7 +100,6 @@ func TestEKSFetcher(t *testing.T) {
},
want: eksClustersToResources(t),
},
-
{
name: "list everything with specified values",
args: args{
@@ -102,14 +113,24 @@ func TestEKSFetcher(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ stsClt := &mocks.STSClient{}
cfg := EKSFetcherConfig{
- ClientGetter: &mockEKSClientGetter{},
+ ClientGetter: &mockEKSClientGetter{
+ AWSConfigProvider: mocks.AWSConfigProvider{
+ STSClient: stsClt,
+ },
+ },
+ AssumeRole: tt.assumeRole,
FilterLabels: tt.args.filterLabels,
Region: tt.args.region,
Logger: utils.NewSlogLoggerForTests(),
}
fetcher, err := NewEKSFetcher(cfg)
require.NoError(t, err)
+ if tt.assumeRole.RoleARN != "" {
+ require.Contains(t, stsClt.GetAssumedRoleARNs(), tt.assumeRole.RoleARN)
+ stsClt.ResetAssumeRoleHistory()
+ }
resources, err := fetcher.Get(context.Background())
require.NoError(t, err)
@@ -123,54 +144,68 @@ func TestEKSFetcher(t *testing.T) {
}
require.Equal(t, tt.want.ToMap(), clusters.ToMap())
+ if tt.assumeRole.RoleARN != "" {
+ require.Contains(t, stsClt.GetAssumedRoleARNs(), tt.assumeRole.RoleARN)
+ }
})
}
}
-type mockEKSClientGetter struct{}
+type mockEKSClientGetter struct {
+ mocks.AWSConfigProvider
+}
+
+func (e *mockEKSClientGetter) GetAWSEKSClient(cfg aws.Config) EKSClient {
+ return newPopulatedEKSMock()
+}
-func (e *mockEKSClientGetter) GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (eksiface.EKSAPI, error) {
- return newPopulatedEKSMock(), nil
+func (e *mockEKSClientGetter) GetAWSSTSClient(aws.Config) STSClient {
+ return &mockSTSAPI{}
}
-func (e *mockEKSClientGetter) GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error) {
- return &mockSTSAPI{}, nil
+func (e *mockEKSClientGetter) GetAWSSTSPresignClient(aws.Config) kubeutils.STSPresignClient {
+ return &mockSTSPresignAPI{}
+}
+
+type mockSTSPresignAPI struct{}
+
+func (a *mockSTSPresignAPI) PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) {
+ panic("not implemented")
}
type mockSTSAPI struct {
- stsiface.STSAPI
arn string
}
-func (a *mockSTSAPI) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) {
+func (a *mockSTSAPI) GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(a.arn),
}, nil
}
+func (a *mockSTSAPI) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
+ panic("not implemented")
+}
+
type mockEKSAPI struct {
- eksiface.EKSAPI
- clusters []*eks.Cluster
+ EKSClient
+
+ clusters []*ekstypes.Cluster
}
-func (m *mockEKSAPI) ListClustersPagesWithContext(ctx aws.Context, req *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error {
- var names []*string
+func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) {
+ var names []string
for _, cluster := range m.clusters {
- names = append(names, cluster.Name)
+ names = append(names, aws.ToString(cluster.Name))
}
- f(&eks.ListClustersOutput{
- Clusters: names[:len(names)/2],
- }, false)
-
- f(&eks.ListClustersOutput{
- Clusters: names[len(names)/2:],
- }, true)
- return nil
+ return &eks.ListClustersOutput{
+ Clusters: names,
+ }, nil
}
-func (m *mockEKSAPI) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
+func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) {
for _, cluster := range m.clusters {
- if aws.StringValue(cluster.Name) == aws.StringValue(req.Name) {
+ if aws.ToString(cluster.Name) == aws.ToString(req.Name) {
return &eks.DescribeClusterOutput{
Cluster: cluster,
}, nil
@@ -185,51 +220,50 @@ func newPopulatedEKSMock() *mockEKSAPI {
}
}
-var eksMockClusters = []*eks.Cluster{
-
+var eksMockClusters = []*ekstypes.Cluster{
{
Name: aws.String("cluster1"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("prod"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "prod",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("cluster2"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("prod"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "prod",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("cluster3"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster3"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("stg"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "stg",
+ "location": "eu-west-1",
},
},
{
Name: aws.String("cluster4"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"),
- Status: aws.String(eks.ClusterStatusActive),
- Tags: map[string]*string{
- "env": aws.String("stg"),
- "location": aws.String("eu-west-1"),
+ Status: ekstypes.ClusterStatusActive,
+ Tags: map[string]string{
+ "env": "stg",
+ "location": "eu-west-1",
},
},
}
-func eksClustersToResources(t *testing.T, clusters ...*eks.Cluster) types.ResourcesWithLabels {
+func eksClustersToResources(t *testing.T, clusters ...*ekstypes.Cluster) types.ResourcesWithLabels {
var kubeClusters types.KubeClusters
for _, cluster := range clusters {
- kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(cluster.Name), aws.StringValue(cluster.Arn), cluster.Tags)
+ kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cluster.Name), aws.ToString(cluster.Arn), cluster.Tags)
require.NoError(t, err)
require.True(t, kubeCluster.IsAWS())
common.ApplyEKSNameSuffix(kubeCluster)
diff --git a/lib/srv/discovery/kube_integration_watcher_test.go b/lib/srv/discovery/kube_integration_watcher_test.go
index 423339678ae8d..3c7cbd57731fd 100644
--- a/lib/srv/discovery/kube_integration_watcher_test.go
+++ b/lib/srv/discovery/kube_integration_watcher_test.go
@@ -26,9 +26,8 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/eks"
- eksTypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
+ ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
- eksV1 "github.com/aws/aws-sdk-go/service/eks"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
@@ -45,7 +44,6 @@ import (
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
- "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
"github.com/gravitational/teleport/lib/services"
@@ -56,22 +54,24 @@ import (
func TestServer_getKubeFetchers(t *testing.T) {
eks1, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{
- ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}},
+ ClientGetter: &mockFetchersClients{},
FilterLabels: types.Labels{"l1": []string{"v1"}},
Region: "region1",
})
require.NoError(t, err)
eks2, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{
- ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}},
+ ClientGetter: &mockFetchersClients{},
FilterLabels: types.Labels{"l1": []string{"v1"}},
Region: "region1",
- Integration: "aws1"})
+ Integration: "aws1",
+ })
require.NoError(t, err)
eks3, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{
- ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}},
+ ClientGetter: &mockFetchersClients{},
FilterLabels: types.Labels{"l1": []string{"v1"}},
Region: "region1",
- Integration: "aws1"})
+ Integration: "aws1",
+ })
require.NoError(t, err)
aks1, err := fetchers.NewAKSFetcher(fetchers.AKSFetcherConfig{
@@ -139,20 +139,51 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
testCAData = "VGVzdENBREFUQQ=="
)
- testEKSClusters := []eksTypes.Cluster{
+ // Create and start test auth server.
+ testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
+ Dir: t.TempDir(),
+ })
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) })
+
+ awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{
+ Name: "integration1",
+ }, &types.AWSOIDCIntegrationSpecV1{
+ RoleARN: roleArn,
+ })
+ require.NoError(t, err)
+ testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{
+ proxies: nil,
+ integrations: map[string]types.Integration{
+ awsOIDCIntegration.GetName(): awsOIDCIntegration,
+ },
+ }
+
+ ctx := context.Background()
+ tlsServer, err := testAuthServer.NewTestTLSServer()
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, tlsServer.Close()) })
+ _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration)
+ require.NoError(t, err)
+
+ fakeConfigProvider := mocks.AWSConfigProvider{
+ OIDCIntegrationClient: tlsServer.Auth(),
+ }
+
+ testEKSClusters := []ekstypes.Cluster{
{
Name: aws.String("eks-cluster1"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"),
Tags: map[string]string{"env": "prod", "location": "eu-west-1"},
- CertificateAuthority: &eksTypes.Certificate{Data: aws.String(testCAData)},
- Status: eksTypes.ClusterStatusActive,
+ CertificateAuthority: &ekstypes.Certificate{Data: aws.String(testCAData)},
+ Status: ekstypes.ClusterStatusActive,
},
{
Name: aws.String("eks-cluster2"),
Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"),
Tags: map[string]string{"env": "prod", "location": "eu-west-1"},
- CertificateAuthority: &eksTypes.Certificate{Data: aws.String(testCAData)},
- Status: eksTypes.ClusterStatusActive,
+ CertificateAuthority: &ekstypes.Certificate{Data: aws.String(testCAData)},
+ Status: ekstypes.ClusterStatusActive,
},
}
@@ -173,7 +204,7 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
return dc
}
- clusterFinder := func(clusterName string) *eksTypes.Cluster {
+ clusterFinder := func(clusterName string) *ekstypes.Cluster {
for _, c := range testEKSClusters {
if aws.ToString(c.Name) == clusterName {
return &c
@@ -309,17 +340,9 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
}
for _, tc := range testCases {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
- testCloudClients := &cloud.TestCloudClients{
- STS: &mocks.STSClientV1{},
- EKS: &mockEKSAPI{
- clusters: eksMockClusters[:2],
- },
- }
-
ctx := context.Background()
// Create and start test auth server.
testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
@@ -372,7 +395,10 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
discServer, err := New(
authz.ContextWithUser(ctx, identity.I),
&Config{
- CloudClients: testCloudClients,
+ AWSFetchersClients: &mockFetchersClients{
+ AWSConfigProvider: fakeConfigProvider,
+ eksClusters: eksMockClusters[:2],
+ },
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
AccessPoint: tc.accessPoint(t, tlsServer.Auth(), authClient),
@@ -391,7 +417,7 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
_, err := tlsServer.Auth().DiscoveryConfigs.CreateDiscoveryConfig(ctx, dc)
require.NoError(t, err)
- // Wait for the DiscoveryConfig to be added to the dynamic fetchers
+ // Wait for the DiscoveryConfig to be added to the dynamic fetchers.
require.Eventually(t, func() bool {
discServer.muDynamicKubeFetchers.RLock()
defer discServer.muDynamicKubeFetchers.RUnlock()
@@ -425,9 +451,9 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
}
}
-func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *eksV1.Cluster, resourceID, discoveryGroup string) types.KubeServer {
- eksCluster.Tags[types.OriginLabel] = aws.String(types.OriginCloud)
- eksCluster.Tags[types.InternalResourceIDLabel] = aws.String(resourceID)
+func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *ekstypes.Cluster, resourceID, _ string) types.KubeServer {
+ eksCluster.Tags[types.OriginLabel] = types.OriginCloud
+ eksCluster.Tags[types.InternalResourceIDLabel] = resourceID
kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksCluster.Tags)
assert.NoError(t, err)
@@ -440,13 +466,13 @@ func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *eksV1.Cluster, resou
return kubeServer
}
-func mustConvertEKSToKubeServerV2(t *testing.T, eksCluster *eksTypes.Cluster, resourceID, discoveryGroup string) types.KubeServer {
- eksTags := make(map[string]*string, len(eksCluster.Tags))
+func mustConvertEKSToKubeServerV2(t *testing.T, eksCluster *ekstypes.Cluster, resourceID, _ string) types.KubeServer {
+ eksTags := make(map[string]string, len(eksCluster.Tags))
for k, v := range eksCluster.Tags {
- eksTags[k] = aws.String(v)
+ eksTags[k] = v
}
- eksTags[types.OriginLabel] = aws.String(types.OriginCloud)
- eksTags[types.InternalResourceIDLabel] = aws.String(resourceID)
+ eksTags[types.OriginLabel] = types.OriginCloud
+ eksTags[types.InternalResourceIDLabel] = resourceID
kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksTags)
assert.NoError(t, err)
@@ -476,9 +502,8 @@ func (a *accessPointWrapper) EnrollEKSClusters(ctx context.Context, req *integra
}
type mockIntegrationsTokenGenerator struct {
- proxies []types.Server
- integrations map[string]types.Integration
- tokenCallsCount int
+ proxies []types.Server
+ integrations map[string]types.Integration
}
// GetIntegration returns the specified integration resources.
@@ -497,7 +522,6 @@ func (m *mockIntegrationsTokenGenerator) GetProxies() ([]types.Server, error) {
// GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action.
func (m *mockIntegrationsTokenGenerator) GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) {
- m.tokenCallsCount++
return uuid.NewString(), nil
}
@@ -509,7 +533,7 @@ type mockEnrollEKSClusterClient struct {
describeCluster func(context.Context, *eks.DescribeClusterInput, ...func(*eks.Options)) (*eks.DescribeClusterOutput, error)
getCallerIdentity func(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
checkAgentAlreadyInstalled func(context.Context, genericclioptions.RESTClientGetter, *slog.Logger) (bool, error)
- installKubeAgent func(context.Context, *eksTypes.Cluster, string, string, string, genericclioptions.RESTClientGetter, *slog.Logger, awsoidc.EnrollEKSClustersRequest) error
+ installKubeAgent func(context.Context, *ekstypes.Cluster, string, string, string, genericclioptions.RESTClientGetter, *slog.Logger, awsoidc.EnrollEKSClustersRequest) error
createToken func(context.Context, types.ProvisionToken) error
presignGetCallerIdentityURL func(ctx context.Context, clusterName string) (string, error)
}
@@ -563,7 +587,7 @@ func (m *mockEnrollEKSClusterClient) CheckAgentAlreadyInstalled(ctx context.Cont
return false, nil
}
-func (m *mockEnrollEKSClusterClient) InstallKubeAgent(ctx context.Context, eksCluster *eksTypes.Cluster, proxyAddr, joinToken, resourceId string, kubeconfig genericclioptions.RESTClientGetter, log *slog.Logger, req awsoidc.EnrollEKSClustersRequest) error {
+func (m *mockEnrollEKSClusterClient) InstallKubeAgent(ctx context.Context, eksCluster *ekstypes.Cluster, proxyAddr, joinToken, resourceId string, kubeconfig genericclioptions.RESTClientGetter, log *slog.Logger, req awsoidc.EnrollEKSClustersRequest) error {
if m.installKubeAgent != nil {
return m.installKubeAgent(ctx, eksCluster, proxyAddr, joinToken, resourceId, kubeconfig, log, req)
}
diff --git a/lib/utils/cli.go b/lib/utils/cli.go
index e79c0bc2aa8f0..648cf7095352f 100644
--- a/lib/utils/cli.go
+++ b/lib/utils/cli.go
@@ -26,7 +26,6 @@ import (
"flag"
"fmt"
"io"
- stdlog "log"
"log/slog"
"os"
"runtime"
@@ -38,7 +37,6 @@ import (
"github.com/alecthomas/kingpin/v2"
"github.com/gravitational/trace"
- "github.com/sirupsen/logrus"
"golang.org/x/term"
"github.com/gravitational/teleport"
@@ -100,59 +98,18 @@ func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption)
opt(&o)
}
- logrus.StandardLogger().ReplaceHooks(make(logrus.LevelHooks))
- logrus.SetLevel(logutils.SlogLevelToLogrusLevel(level))
-
- var (
- w io.Writer
- enableColors bool
- )
- switch purpose {
- case LoggingForCLI:
- // If debug logging was asked for on the CLI, then write logs to stderr.
- // Otherwise, discard all logs.
- if level == slog.LevelDebug {
- enableColors = IsTerminal(os.Stderr)
- w = logutils.NewSharedWriter(os.Stderr)
- } else {
- w = io.Discard
- enableColors = false
- }
- case LoggingForDaemon:
- enableColors = IsTerminal(os.Stderr)
- w = logutils.NewSharedWriter(os.Stderr)
- }
-
- var (
- formatter logrus.Formatter
- handler slog.Handler
- )
- switch o.format {
- case LogFormatText, "":
- textFormatter := logutils.NewDefaultTextFormatter(enableColors)
-
- // Calling CheckAndSetDefaults enables the timestamp field to
- // be included in the output. The error returned is ignored
- // because the default formatter cannot be invalid.
- if purpose == LoggingForCLI && level == slog.LevelDebug {
- _ = textFormatter.CheckAndSetDefaults()
- }
-
- formatter = textFormatter
- handler = logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{
- Level: level,
- EnableColors: enableColors,
- })
- case LogFormatJSON:
- formatter = &logutils.JSONFormatter{}
- handler = logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{
- Level: level,
- })
+ // If debug or trace logging is not enabled for CLIs,
+ // then discard all log output.
+ if purpose == LoggingForCLI && level > slog.LevelDebug {
+ slog.SetDefault(slog.New(logutils.DiscardHandler{}))
+ return
}
- logrus.SetFormatter(formatter)
- logrus.SetOutput(w)
- slog.SetDefault(slog.New(handler))
+ logutils.Initialize(logutils.Config{
+ Severity: level.String(),
+ Format: o.format,
+ EnableColors: IsTerminal(os.Stderr),
+ })
}
var initTestLoggerOnce = sync.Once{}
@@ -163,56 +120,24 @@ func InitLoggerForTests() {
// Parse flags to check testing.Verbose().
flag.Parse()
- level := slog.LevelWarn
- w := io.Discard
- if testing.Verbose() {
- level = slog.LevelDebug
- w = os.Stderr
+ if !testing.Verbose() {
+ slog.SetDefault(slog.New(logutils.DiscardHandler{}))
+ return
}
- logger := logrus.StandardLogger()
- logger.SetFormatter(logutils.NewTestJSONFormatter())
- logger.SetLevel(logutils.SlogLevelToLogrusLevel(level))
-
- output := logutils.NewSharedWriter(w)
- logger.SetOutput(output)
- slog.SetDefault(slog.New(logutils.NewSlogJSONHandler(output, logutils.SlogJSONHandlerConfig{Level: level})))
+ logutils.Initialize(logutils.Config{
+ Severity: slog.LevelDebug.String(),
+ Format: LogFormatJSON,
+ })
})
}
-// NewLoggerForTests creates a new logrus logger for test environments.
-func NewLoggerForTests() *logrus.Logger {
- InitLoggerForTests()
- return logrus.StandardLogger()
-}
-
// NewSlogLoggerForTests creates a new slog logger for test environments.
func NewSlogLoggerForTests() *slog.Logger {
InitLoggerForTests()
return slog.Default()
}
-// WrapLogger wraps an existing logger entry and returns
-// a value satisfying the Logger interface
-func WrapLogger(logger *logrus.Entry) Logger {
- return &logWrapper{Entry: logger}
-}
-
-// NewLogger creates a new empty logrus logger.
-func NewLogger() *logrus.Logger {
- return logrus.StandardLogger()
-}
-
-// Logger describes a logger value
-type Logger interface {
- logrus.FieldLogger
- // GetLevel specifies the level at which this logger
- // value is logging
- GetLevel() logrus.Level
- // SetLevel sets the logger's level to the specified value
- SetLevel(level logrus.Level)
-}
-
// FatalError is for CLI front-ends: it detects gravitational/trace debugging
// information, sends it to the logger, strips it off and prints a clean message to stderr
func FatalError(err error) {
@@ -231,7 +156,7 @@ func GetIterations() int {
if err != nil {
panic(err)
}
- logrus.Debugf("Starting tests with %v iterations.", iter)
+ slog.DebugContext(context.Background(), "Running tests multiple times due to presence of ITERATIONS environment variable", "iterations", iter)
return iter
}
@@ -484,47 +409,6 @@ func AllowWhitespace(s string) string {
return sb.String()
}
-// NewStdlogger creates a new stdlib logger that uses the specified leveled logger
-// for output and the given component as a logging prefix.
-func NewStdlogger(logger LeveledOutputFunc, component string) *stdlog.Logger {
- return stdlog.New(&stdlogAdapter{
- log: logger,
- }, component, stdlog.LstdFlags)
-}
-
-// Write writes the specified buffer p to the underlying leveled logger.
-// Implements io.Writer
-func (r *stdlogAdapter) Write(p []byte) (n int, err error) {
- r.log(string(p))
- return len(p), nil
-}
-
-// stdlogAdapter is an io.Writer that writes into an instance
-// of logrus.Logger
-type stdlogAdapter struct {
- log LeveledOutputFunc
-}
-
-// LeveledOutputFunc describes a function that emits given
-// arguments at a specific level to an underlying logger
-type LeveledOutputFunc func(args ...interface{})
-
-// GetLevel returns the level of the underlying logger
-func (r *logWrapper) GetLevel() logrus.Level {
- return r.Entry.Logger.GetLevel()
-}
-
-// SetLevel sets the logging level to the given value
-func (r *logWrapper) SetLevel(level logrus.Level) {
- r.Entry.Logger.SetLevel(level)
-}
-
-// logWrapper wraps a log entry.
-// Implements Logger
-type logWrapper struct {
- *logrus.Entry
-}
-
// needsQuoting returns true if any non-printable characters are found.
func needsQuoting(text string) bool {
for _, r := range text {
diff --git a/lib/utils/log/formatter_test.go b/lib/utils/log/formatter_test.go
index 9abb0310ba0be..aff0ec8be3a74 100644
--- a/lib/utils/log/formatter_test.go
+++ b/lib/utils/log/formatter_test.go
@@ -22,7 +22,6 @@ import (
"bytes"
"context"
"encoding/json"
- "errors"
"fmt"
"io"
"log/slog"
@@ -38,7 +37,6 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
- "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -48,7 +46,7 @@ import (
const message = "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr."
var (
- logErr = errors.New("the quick brown fox jumped really high")
+ logErr = &trace.BadParameterError{Message: "the quick brown fox jumped really high"}
addr = fakeAddr{addr: "127.0.0.1:1234"}
fields = map[string]any{
@@ -72,6 +70,10 @@ func (a fakeAddr) String() string {
return a.addr
}
+func (a fakeAddr) MarshalText() (text []byte, err error) {
+ return []byte(a.addr), nil
+}
+
func TestOutput(t *testing.T) {
loc, err := time.LoadLocation("Africa/Cairo")
require.NoError(t, err, "failed getting timezone")
@@ -89,58 +91,50 @@ func TestOutput(t *testing.T) {
// 4) the caller
outputRegex := regexp.MustCompile(`(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z)(\s+.*)(".*diag_addr\.")(.*)(\slog/formatter_test.go:\d{3})`)
+ expectedFields := map[string]string{
+ "local": addr.String(),
+ "remote": addr.String(),
+ "login": "llama",
+ "teleportUser": "user",
+ "id": "1234",
+ "test": "123",
+ "animal": `"llama\n"`,
+ "error": "[" + trace.DebugReport(logErr) + "]",
+ "diag_addr": addr.String(),
+ }
+
tests := []struct {
- name string
- logrusLevel logrus.Level
- slogLevel slog.Level
+ name string
+ slogLevel slog.Level
}{
{
- name: "trace",
- logrusLevel: logrus.TraceLevel,
- slogLevel: TraceLevel,
+ name: "trace",
+ slogLevel: TraceLevel,
},
{
- name: "debug",
- logrusLevel: logrus.DebugLevel,
- slogLevel: slog.LevelDebug,
+ name: "debug",
+ slogLevel: slog.LevelDebug,
},
{
- name: "info",
- logrusLevel: logrus.InfoLevel,
- slogLevel: slog.LevelInfo,
+ name: "info",
+ slogLevel: slog.LevelInfo,
},
{
- name: "warn",
- logrusLevel: logrus.WarnLevel,
- slogLevel: slog.LevelWarn,
+ name: "warn",
+ slogLevel: slog.LevelWarn,
},
{
- name: "error",
- logrusLevel: logrus.ErrorLevel,
- slogLevel: slog.LevelError,
+ name: "error",
+ slogLevel: slog.LevelError,
},
{
- name: "fatal",
- logrusLevel: logrus.FatalLevel,
- slogLevel: slog.LevelError + 1,
+ name: "fatal",
+ slogLevel: slog.LevelError + 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- // Create a logrus logger using the custom formatter which outputs to a local buffer.
- var logrusOutput bytes.Buffer
- formatter := NewDefaultTextFormatter(true)
- formatter.timestampEnabled = true
- require.NoError(t, formatter.CheckAndSetDefaults())
-
- logrusLogger := logrus.New()
- logrusLogger.SetFormatter(formatter)
- logrusLogger.SetOutput(&logrusOutput)
- logrusLogger.ReplaceHooks(logrus.LevelHooks{})
- logrusLogger.SetLevel(test.logrusLevel)
- entry := logrusLogger.WithField(teleport.ComponentKey, "test").WithTime(clock.Now().UTC())
-
// Create a slog logger using the custom handler which outputs to a local buffer.
var slogOutput bytes.Buffer
slogConfig := SlogTextHandlerConfig{
@@ -155,13 +149,6 @@ func TestOutput(t *testing.T) {
}
slogLogger := slog.New(NewSlogTextHandler(&slogOutput, slogConfig)).With(teleport.ComponentKey, "test")
- // Add some fields and output the message at the desired log level via logrus.
- l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr)
- logrusTestLogLineNumber := func() int {
- l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Log(test.logrusLevel, message)
- return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it
- }()
-
// Add some fields and output the message at the desired log level via slog.
l2 := slogLogger.With("test", 123).With("animal", "llama\n").With("error", logErr)
slogTestLogLineNumber := func() int {
@@ -169,163 +156,144 @@ func TestOutput(t *testing.T) {
return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it
}()
- // Validate that both loggers produces the same output. The added complexity comes from the fact that
- // our custom slog handler does NOT sort the additional fields like our logrus formatter does.
- logrusMatches := outputRegex.FindStringSubmatch(logrusOutput.String())
- require.NotEmpty(t, logrusMatches, "logrus output was in unexpected format: %s", logrusOutput.String())
+ // Validate the logger output. The added complexity comes from the fact that
+ // our custom slog handler does NOT sort the additional fields.
slogMatches := outputRegex.FindStringSubmatch(slogOutput.String())
require.NotEmpty(t, slogMatches, "slog output was in unexpected format: %s", slogOutput.String())
// The first match is the timestamp: 2023-10-31T10:09:06+02:00
- logrusTime, err := time.Parse(time.RFC3339, logrusMatches[1])
- assert.NoError(t, err, "invalid logrus timestamp found %s", logrusMatches[1])
-
slogTime, err := time.Parse(time.RFC3339, slogMatches[1])
assert.NoError(t, err, "invalid slog timestamp found %s", slogMatches[1])
-
- assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10)
+ assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10)
// Match level, and component: DEBU [TEST]
- assert.Empty(t, cmp.Diff(logrusMatches[2], slogMatches[2]), "level, and component to be identical")
- // Match the log message: "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr.\n"
- assert.Empty(t, cmp.Diff(logrusMatches[3], slogMatches[3]), "expected output messages to be identical")
+ expectedLevel := formatLevel(test.slogLevel, true)
+ expectedComponent := formatComponent(slog.StringValue("test"), defaultComponentPadding)
+ expectedMatch := " " + expectedLevel + " " + expectedComponent + " "
+ assert.Equal(t, expectedMatch, slogMatches[2], "level, and component to be identical")
+ // Match the log message
+ assert.Equal(t, `"Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr."`, slogMatches[3], "expected output messages to be identical")
// The last matches are the caller information
- assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", logrusTestLogLineNumber), logrusMatches[5])
assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", slogTestLogLineNumber), slogMatches[5])
// The third matches are the fields which will be key value pairs(animal:llama) separated by a space. Since
- // logrus sorts the fields and slog doesn't we can't just assert equality and instead build a map of the key
+ // slog doesn't sort the fields, we can't assert equality and instead build a map of the key
// value pairs to ensure they are all present and accounted for.
- logrusFieldMatches := fieldsRegex.FindAllStringSubmatch(logrusMatches[4], -1)
slogFieldMatches := fieldsRegex.FindAllStringSubmatch(slogMatches[4], -1)
// The first match is the key, the second match is the value
- logrusFields := map[string]string{}
- for _, match := range logrusFieldMatches {
- logrusFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2])
- }
-
slogFields := map[string]string{}
for _, match := range slogFieldMatches {
slogFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2])
}
- assert.Equal(t, slogFields, logrusFields)
+ require.Empty(t,
+ cmp.Diff(
+ expectedFields,
+ slogFields,
+ cmpopts.SortMaps(func(a, b string) bool { return a < b }),
+ ),
+ )
})
}
})
t.Run("json", func(t *testing.T) {
tests := []struct {
- name string
- logrusLevel logrus.Level
- slogLevel slog.Level
+ name string
+ slogLevel slog.Level
}{
{
- name: "trace",
- logrusLevel: logrus.TraceLevel,
- slogLevel: TraceLevel,
+ name: "trace",
+ slogLevel: TraceLevel,
},
{
- name: "debug",
- logrusLevel: logrus.DebugLevel,
- slogLevel: slog.LevelDebug,
+ name: "debug",
+ slogLevel: slog.LevelDebug,
},
{
- name: "info",
- logrusLevel: logrus.InfoLevel,
- slogLevel: slog.LevelInfo,
+ name: "info",
+ slogLevel: slog.LevelInfo,
},
{
- name: "warn",
- logrusLevel: logrus.WarnLevel,
- slogLevel: slog.LevelWarn,
+ name: "warn",
+ slogLevel: slog.LevelWarn,
},
{
- name: "error",
- logrusLevel: logrus.ErrorLevel,
- slogLevel: slog.LevelError,
+ name: "error",
+ slogLevel: slog.LevelError,
},
{
- name: "fatal",
- logrusLevel: logrus.FatalLevel,
- slogLevel: slog.LevelError + 1,
+ name: "fatal",
+ slogLevel: slog.LevelError + 1,
+ },
+ }
+
+ expectedFields := map[string]any{
+ "trace.fields": map[string]any{
+ "teleportUser": "user",
+ "id": float64(1234),
+ "local": addr.String(),
+ "login": "llama",
+ "remote": addr.String(),
},
+ "test": float64(123),
+ "animal": `llama`,
+ "error": logErr.Error(),
+ "diag_addr": addr.String(),
+ "component": "test",
+ "message": message,
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- // Create a logrus logger using the custom formatter which outputs to a local buffer.
- var logrusOut bytes.Buffer
- formatter := &JSONFormatter{
- ExtraFields: nil,
- callerEnabled: true,
- }
- require.NoError(t, formatter.CheckAndSetDefaults())
-
- logrusLogger := logrus.New()
- logrusLogger.SetFormatter(formatter)
- logrusLogger.SetOutput(&logrusOut)
- logrusLogger.ReplaceHooks(logrus.LevelHooks{})
- logrusLogger.SetLevel(test.logrusLevel)
- entry := logrusLogger.WithField(teleport.ComponentKey, "test")
-
// Create a slog logger using the custom formatter which outputs to a local buffer.
var slogOutput bytes.Buffer
slogLogger := slog.New(NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{Level: test.slogLevel})).With(teleport.ComponentKey, "test")
- // Add some fields and output the message at the desired log level via logrus.
- l := entry.WithField("test", 123).WithField("animal", "llama").WithField("error", trace.Wrap(logErr))
- logrusTestLogLineNumber := func() int {
- l.WithField("diag_addr", addr.String()).Log(test.logrusLevel, message)
- return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it
- }()
-
// Add some fields and output the message at the desired log level via slog.
l2 := slogLogger.With("test", 123).With("animal", "llama").With("error", trace.Wrap(logErr))
slogTestLogLineNumber := func() int {
- l2.Log(context.Background(), test.slogLevel, message, "diag_addr", &addr)
+ l2.With(teleport.ComponentFields, fields).Log(context.Background(), test.slogLevel, message, "diag_addr", &addr)
return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it
}()
- // The order of the fields emitted by the two loggers is different, so comparing the output directly
- // for equality won't work. Instead, a map is built with all the key value pairs, excluding the caller
- // and that map is compared to ensure all items are present and match.
- var logrusData map[string]any
- require.NoError(t, json.Unmarshal(logrusOut.Bytes(), &logrusData), "invalid logrus output format")
-
var slogData map[string]any
require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData), "invalid slog output format")
- logrusCaller, ok := logrusData["caller"].(string)
- delete(logrusData, "caller")
- assert.True(t, ok, "caller was missing from logrus output")
- assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", logrusTestLogLineNumber), logrusCaller)
-
slogCaller, ok := slogData["caller"].(string)
delete(slogData, "caller")
assert.True(t, ok, "caller was missing from slog output")
assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", slogTestLogLineNumber), slogCaller)
- logrusTimestamp, ok := logrusData["timestamp"].(string)
- delete(logrusData, "timestamp")
- assert.True(t, ok, "time was missing from logrus output")
+ slogLevel, ok := slogData["level"].(string)
+ delete(slogData, "level")
+ assert.True(t, ok, "level was missing from slog output")
+ var expectedLevel string
+ switch test.slogLevel {
+ case TraceLevel:
+ expectedLevel = "trace"
+ case slog.LevelWarn:
+ expectedLevel = "warning"
+ case slog.LevelError + 1:
+ expectedLevel = "fatal"
+ default:
+ expectedLevel = test.slogLevel.String()
+ }
+ assert.Equal(t, strings.ToLower(expectedLevel), slogLevel)
slogTimestamp, ok := slogData["timestamp"].(string)
delete(slogData, "timestamp")
assert.True(t, ok, "time was missing from slog output")
- logrusTime, err := time.Parse(time.RFC3339, logrusTimestamp)
- assert.NoError(t, err, "invalid logrus timestamp %s", logrusTimestamp)
-
slogTime, err := time.Parse(time.RFC3339, slogTimestamp)
assert.NoError(t, err, "invalid slog timestamp %s", slogTimestamp)
- assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10)
+ assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10)
require.Empty(t,
cmp.Diff(
- logrusData,
+ expectedFields,
slogData,
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
),
@@ -347,38 +315,6 @@ func getCallerLineNumber() int {
func BenchmarkFormatter(b *testing.B) {
ctx := context.Background()
b.ReportAllocs()
- b.Run("logrus", func(b *testing.B) {
- b.Run("text", func(b *testing.B) {
- formatter := NewDefaultTextFormatter(true)
- require.NoError(b, formatter.CheckAndSetDefaults())
- logger := logrus.New()
- logger.SetFormatter(formatter)
- logger.SetOutput(io.Discard)
- b.ResetTimer()
-
- entry := logger.WithField(teleport.ComponentKey, "test")
- for i := 0; i < b.N; i++ {
- l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr)
- l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message)
- }
- })
-
- b.Run("json", func(b *testing.B) {
- formatter := &JSONFormatter{}
- require.NoError(b, formatter.CheckAndSetDefaults())
- logger := logrus.New()
- logger.SetFormatter(formatter)
- logger.SetOutput(io.Discard)
- logger.ReplaceHooks(logrus.LevelHooks{})
- b.ResetTimer()
-
- entry := logger.WithField(teleport.ComponentKey, "test")
- for i := 0; i < b.N; i++ {
- l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr)
- l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message)
- }
- })
- })
b.Run("slog", func(b *testing.B) {
b.Run("default_text", func(b *testing.B) {
@@ -430,47 +366,26 @@ func BenchmarkFormatter(b *testing.B) {
}
func TestConcurrentOutput(t *testing.T) {
- t.Run("logrus", func(t *testing.T) {
- debugFormatter := NewDefaultTextFormatter(true)
- require.NoError(t, debugFormatter.CheckAndSetDefaults())
- logrus.SetFormatter(debugFormatter)
- logrus.SetOutput(os.Stdout)
-
- logger := logrus.WithField(teleport.ComponentKey, "test")
-
- var wg sync.WaitGroup
- for i := 0; i < 1000; i++ {
- wg.Add(1)
- go func(i int) {
- defer wg.Done()
- logger.Infof("Detected Teleport component %d is running in a degraded state.", i)
- }(i)
- }
- wg.Wait()
- })
+ logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{
+ EnableColors: true,
+ })).With(teleport.ComponentKey, "test")
- t.Run("slog", func(t *testing.T) {
- logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{
- EnableColors: true,
- })).With(teleport.ComponentKey, "test")
-
- var wg sync.WaitGroup
- ctx := context.Background()
- for i := 0; i < 1000; i++ {
- wg.Add(1)
- go func(i int) {
- defer wg.Done()
- logger.InfoContext(ctx, "Teleport component entered degraded state",
- slog.Int("component", i),
- slog.Group("group",
- slog.String("test", "123"),
- slog.String("animal", "llama"),
- ),
- )
- }(i)
- }
- wg.Wait()
- })
+ var wg sync.WaitGroup
+ ctx := context.Background()
+ for i := 0; i < 1000; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ logger.InfoContext(ctx, "Teleport component entered degraded state",
+ slog.Int("component", i),
+ slog.Group("group",
+ slog.String("test", "123"),
+ slog.String("animal", "llama"),
+ ),
+ )
+ }(i)
+ }
+ wg.Wait()
}
// allPossibleSubsets returns all combinations of subsets for the
@@ -493,58 +408,34 @@ func allPossibleSubsets(in []string) [][]string {
return subsets
}
-// TestExtraFields validates that the output is identical for the
-// logrus formatter and slog handler based on the configured extra
-// fields.
+// TestExtraFields validates that the output is expected for the
+// slog handler based on the configured extra fields.
func TestExtraFields(t *testing.T) {
// Capture a fake time that all output will use.
now := clockwork.NewFakeClock().Now()
// Capture the caller information to be injected into all messages.
pc, _, _, _ := runtime.Caller(0)
- fs := runtime.CallersFrames([]uintptr{pc})
- f, _ := fs.Next()
- callerTrace := &trace.Trace{
- Func: f.Function,
- Path: f.File,
- Line: f.Line,
- }
const message = "testing 123"
- // Test against every possible configured combination of allowed format fields.
- fields := allPossibleSubsets(defaultFormatFields)
-
t.Run("text", func(t *testing.T) {
- for _, configuredFields := range fields {
+ // Test against every possible configured combination of allowed format fields.
+ for _, configuredFields := range allPossibleSubsets(defaultFormatFields) {
name := "not configured"
if len(configuredFields) > 0 {
name = strings.Join(configuredFields, " ")
}
t.Run(name, func(t *testing.T) {
- logrusFormatter := TextFormatter{
- ExtraFields: configuredFields,
- }
- // Call CheckAndSetDefaults to exercise the extra fields logic. Since
- // FormatCaller is always overridden within CheckAndSetDefaults, it is
- // explicitly set afterward so the caller points to our fake call site.
- require.NoError(t, logrusFormatter.CheckAndSetDefaults())
- logrusFormatter.FormatCaller = callerTrace.String
-
- var slogOutput bytes.Buffer
- var slogHandler slog.Handler = NewSlogTextHandler(&slogOutput, SlogTextHandlerConfig{ConfiguredFields: configuredFields})
-
- entry := &logrus.Entry{
- Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"},
- Time: now,
- Level: logrus.DebugLevel,
- Caller: &f,
- Message: message,
- }
-
- logrusOut, err := logrusFormatter.Format(entry)
- require.NoError(t, err)
+ replaced := map[string]struct{}{}
+ var slogHandler slog.Handler = NewSlogTextHandler(io.Discard, SlogTextHandlerConfig{
+ ConfiguredFields: configuredFields,
+ ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
+ replaced[a.Key] = struct{}{}
+ return a
+ },
+ })
record := slog.Record{
Time: now,
@@ -557,42 +448,29 @@ func TestExtraFields(t *testing.T) {
require.NoError(t, slogHandler.Handle(context.Background(), record))
- require.Equal(t, string(logrusOut), slogOutput.String())
+ for k := range replaced {
+ delete(replaced, k)
+ }
+
+ require.Empty(t, replaced, replaced)
})
}
})
t.Run("json", func(t *testing.T) {
- for _, configuredFields := range fields {
+ // Test against every possible configured combination of allowed format fields.
+ // Note, the json handler limits the allowed fields to a subset of those allowed
+ // by the text handler.
+ for _, configuredFields := range allPossibleSubsets([]string{CallerField, ComponentField, TimestampField}) {
name := "not configured"
if len(configuredFields) > 0 {
name = strings.Join(configuredFields, " ")
}
t.Run(name, func(t *testing.T) {
- logrusFormatter := JSONFormatter{
- ExtraFields: configuredFields,
- }
- // Call CheckAndSetDefaults to exercise the extra fields logic. Since
- // FormatCaller is always overridden within CheckAndSetDefaults, it is
- // explicitly set afterward so the caller points to our fake call site.
- require.NoError(t, logrusFormatter.CheckAndSetDefaults())
- logrusFormatter.FormatCaller = callerTrace.String
-
var slogOutput bytes.Buffer
var slogHandler slog.Handler = NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{ConfiguredFields: configuredFields})
- entry := &logrus.Entry{
- Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"},
- Time: now,
- Level: logrus.DebugLevel,
- Caller: &f,
- Message: message,
- }
-
- logrusOut, err := logrusFormatter.Format(entry)
- require.NoError(t, err)
-
record := slog.Record{
Time: now,
Message: message,
@@ -604,11 +482,31 @@ func TestExtraFields(t *testing.T) {
require.NoError(t, slogHandler.Handle(context.Background(), record))
- var slogData, logrusData map[string]any
- require.NoError(t, json.Unmarshal(logrusOut, &logrusData))
+ var slogData map[string]any
require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData))
- require.Equal(t, slogData, logrusData)
+ delete(slogData, "animal")
+ delete(slogData, "vegetable")
+ delete(slogData, "message")
+ delete(slogData, "level")
+
+ var expectedLen int
+ expectedFields := configuredFields
+ switch l := len(configuredFields); l {
+ case 0:
+ // The level field was removed above, but is included in the default fields
+ expectedLen = len(defaultFormatFields) - 1
+ expectedFields = defaultFormatFields
+ default:
+ expectedLen = l
+ }
+ require.Len(t, slogData, expectedLen, slogData)
+
+ for _, f := range expectedFields {
+ delete(slogData, f)
+ }
+
+ require.Empty(t, slogData, slogData)
})
}
})
diff --git a/lib/utils/log/log.go b/lib/utils/log/log.go
index 2f16b902e3df6..d8aadb75146bf 100644
--- a/lib/utils/log/log.go
+++ b/lib/utils/log/log.go
@@ -42,6 +42,8 @@ type Config struct {
ExtraFields []string
// EnableColors dictates if output should be colored.
EnableColors bool
+ // Padding to use for various components.
+ Padding int
}
// Initialize configures the default global logger based on the
@@ -112,6 +114,7 @@ func Initialize(loggerConfig Config) (*slog.Logger, *slog.LevelVar, error) {
Level: level,
EnableColors: loggerConfig.EnableColors,
ConfiguredFields: configuredFields,
+ Padding: loggerConfig.Padding,
}))
slog.SetDefault(logger)
case "json":
diff --git a/lib/utils/log/logrus_formatter.go b/lib/utils/log/logrus_formatter.go
deleted file mode 100644
index 14ad8441da7cc..0000000000000
--- a/lib/utils/log/logrus_formatter.go
+++ /dev/null
@@ -1,427 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package log
-
-import (
- "fmt"
- "regexp"
- "runtime"
- "slices"
- "strconv"
- "strings"
-
- "github.com/gravitational/trace"
- "github.com/sirupsen/logrus"
-
- "github.com/gravitational/teleport"
-)
-
-// TextFormatter is a [logrus.Formatter] that outputs messages in
-// a textual format.
-type TextFormatter struct {
- // ComponentPadding is a padding to pick when displaying
- // and formatting component field, defaults to DefaultComponentPadding
- ComponentPadding int
- // EnableColors enables colored output
- EnableColors bool
- // FormatCaller is a function to return (part) of source file path for output.
- // Defaults to filePathAndLine() if unspecified
- FormatCaller func() (caller string)
- // ExtraFields represent the extra fields that will be added to the log message
- ExtraFields []string
- // TimestampEnabled specifies if timestamp is enabled in logs
- timestampEnabled bool
- // CallerEnabled specifies if caller is enabled in logs
- callerEnabled bool
-}
-
-type writer struct {
- b *buffer
-}
-
-func newWriter() *writer {
- return &writer{b: &buffer{}}
-}
-
-func (w *writer) Len() int {
- return len(*w.b)
-}
-
-func (w *writer) WriteString(s string) (int, error) {
- return w.b.WriteString(s)
-}
-
-func (w *writer) WriteByte(c byte) error {
- return w.b.WriteByte(c)
-}
-
-func (w *writer) Bytes() []byte {
- return *w.b
-}
-
-// NewDefaultTextFormatter creates a TextFormatter with
-// the default options set.
-func NewDefaultTextFormatter(enableColors bool) *TextFormatter {
- return &TextFormatter{
- ComponentPadding: defaultComponentPadding,
- FormatCaller: formatCallerWithPathAndLine,
- ExtraFields: defaultFormatFields,
- EnableColors: enableColors,
- callerEnabled: true,
- timestampEnabled: false,
- }
-}
-
-// CheckAndSetDefaults checks and sets log format configuration.
-func (tf *TextFormatter) CheckAndSetDefaults() error {
- // set padding
- if tf.ComponentPadding == 0 {
- tf.ComponentPadding = defaultComponentPadding
- }
- // set caller
- tf.FormatCaller = formatCallerWithPathAndLine
-
- // set log formatting
- if tf.ExtraFields == nil {
- tf.timestampEnabled = true
- tf.callerEnabled = true
- tf.ExtraFields = defaultFormatFields
- return nil
- }
-
- if slices.Contains(tf.ExtraFields, TimestampField) {
- tf.timestampEnabled = true
- }
-
- if slices.Contains(tf.ExtraFields, CallerField) {
- tf.callerEnabled = true
- }
-
- return nil
-}
-
-// Format formats each log line as configured in teleport config file.
-func (tf *TextFormatter) Format(e *logrus.Entry) ([]byte, error) {
- caller := tf.FormatCaller()
- w := newWriter()
-
- // write timestamp first if enabled
- if tf.timestampEnabled {
- *w.b = appendRFC3339Millis(*w.b, e.Time.Round(0))
- }
-
- for _, field := range tf.ExtraFields {
- switch field {
- case LevelField:
- var color int
- var level string
- switch e.Level {
- case logrus.TraceLevel:
- level = "TRACE"
- color = gray
- case logrus.DebugLevel:
- level = "DEBUG"
- color = gray
- case logrus.InfoLevel:
- level = "INFO"
- color = blue
- case logrus.WarnLevel:
- level = "WARN"
- color = yellow
- case logrus.ErrorLevel:
- level = "ERROR"
- color = red
- case logrus.FatalLevel:
- level = "FATAL"
- color = red
- default:
- color = blue
- level = strings.ToUpper(e.Level.String())
- }
-
- if !tf.EnableColors {
- color = noColor
- }
-
- w.writeField(padMax(level, defaultLevelPadding), color)
- case ComponentField:
- padding := defaultComponentPadding
- if tf.ComponentPadding != 0 {
- padding = tf.ComponentPadding
- }
- if w.Len() > 0 {
- w.WriteByte(' ')
- }
- component, ok := e.Data[teleport.ComponentKey].(string)
- if ok && component != "" {
- component = fmt.Sprintf("[%v]", component)
- }
- component = strings.ToUpper(padMax(component, padding))
- if component[len(component)-1] != ' ' {
- component = component[:len(component)-1] + "]"
- }
-
- w.WriteString(component)
- default:
- if _, ok := knownFormatFields[field]; !ok {
- return nil, trace.BadParameter("invalid log format key: %v", field)
- }
- }
- }
-
- // always use message
- if e.Message != "" {
- w.writeField(e.Message, noColor)
- }
-
- if len(e.Data) > 0 {
- w.writeMap(e.Data)
- }
-
- // write caller last if enabled
- if tf.callerEnabled && caller != "" {
- w.writeField(caller, noColor)
- }
-
- w.WriteByte('\n')
- return w.Bytes(), nil
-}
-
-// JSONFormatter implements the [logrus.Formatter] interface and adds extra
-// fields to log entries.
-type JSONFormatter struct {
- logrus.JSONFormatter
-
- ExtraFields []string
- // FormatCaller is a function to return (part) of source file path for output.
- // Defaults to filePathAndLine() if unspecified
- FormatCaller func() (caller string)
-
- callerEnabled bool
- componentEnabled bool
-}
-
-// CheckAndSetDefaults checks and sets log format configuration.
-func (j *JSONFormatter) CheckAndSetDefaults() error {
- // set log formatting
- if j.ExtraFields == nil {
- j.ExtraFields = defaultFormatFields
- }
- // set caller
- j.FormatCaller = formatCallerWithPathAndLine
-
- if slices.Contains(j.ExtraFields, CallerField) {
- j.callerEnabled = true
- }
-
- if slices.Contains(j.ExtraFields, ComponentField) {
- j.componentEnabled = true
- }
-
- // rename default fields
- j.JSONFormatter = logrus.JSONFormatter{
- FieldMap: logrus.FieldMap{
- logrus.FieldKeyTime: TimestampField,
- logrus.FieldKeyLevel: LevelField,
- logrus.FieldKeyMsg: messageField,
- },
- DisableTimestamp: !slices.Contains(j.ExtraFields, TimestampField),
- }
-
- return nil
-}
-
-// Format formats each log line as configured in teleport config file.
-func (j *JSONFormatter) Format(e *logrus.Entry) ([]byte, error) {
- if j.callerEnabled {
- path := j.FormatCaller()
- e.Data[CallerField] = path
- }
-
- if j.componentEnabled {
- e.Data[ComponentField] = e.Data[teleport.ComponentKey]
- }
-
- delete(e.Data, teleport.ComponentKey)
-
- return j.JSONFormatter.Format(e)
-}
-
-// NewTestJSONFormatter creates a JSONFormatter that is
-// configured for output in tests.
-func NewTestJSONFormatter() *JSONFormatter {
- formatter := &JSONFormatter{}
- if err := formatter.CheckAndSetDefaults(); err != nil {
- panic(err)
- }
- return formatter
-}
-
-func (w *writer) writeError(value interface{}) {
- switch err := value.(type) {
- case trace.Error:
- *w.b = fmt.Appendf(*w.b, "[%v]", err.DebugReport())
- default:
- *w.b = fmt.Appendf(*w.b, "[%v]", value)
- }
-}
-
-func (w *writer) writeField(value interface{}, color int) {
- if w.Len() > 0 {
- w.WriteByte(' ')
- }
- w.writeValue(value, color)
-}
-
-func (w *writer) writeKeyValue(key string, value interface{}) {
- if w.Len() > 0 {
- w.WriteByte(' ')
- }
- w.WriteString(key)
- w.WriteByte(':')
- if key == logrus.ErrorKey {
- w.writeError(value)
- return
- }
- w.writeValue(value, noColor)
-}
-
-func (w *writer) writeValue(value interface{}, color int) {
- if s, ok := value.(string); ok {
- if color != noColor {
- *w.b = fmt.Appendf(*w.b, "\u001B[%dm", color)
- }
-
- if needsQuoting(s) {
- *w.b = strconv.AppendQuote(*w.b, s)
- } else {
- *w.b = fmt.Append(*w.b, s)
- }
-
- if color != noColor {
- *w.b = fmt.Append(*w.b, "\u001B[0m")
- }
- return
- }
-
- if color != noColor {
- *w.b = fmt.Appendf(*w.b, "\x1b[%dm%v\x1b[0m", color, value)
- return
- }
-
- *w.b = fmt.Appendf(*w.b, "%v", value)
-}
-
-func (w *writer) writeMap(m map[string]any) {
- if len(m) == 0 {
- return
- }
- keys := make([]string, 0, len(m))
- for key := range m {
- keys = append(keys, key)
- }
- slices.Sort(keys)
- for _, key := range keys {
- if key == teleport.ComponentKey {
- continue
- }
- switch value := m[key].(type) {
- case map[string]any:
- w.writeMap(value)
- case logrus.Fields:
- w.writeMap(value)
- default:
- w.writeKeyValue(key, value)
- }
- }
-}
-
-type frameCursor struct {
- // current specifies the current stack frame.
- // if omitted, rest contains the complete stack
- current *runtime.Frame
- // rest specifies the rest of stack frames to explore
- rest *runtime.Frames
- // n specifies the total number of stack frames
- n int
-}
-
-// formatCallerWithPathAndLine formats the caller in the form path/segment:
-// for output in the log
-func formatCallerWithPathAndLine() (path string) {
- if cursor := findFrame(); cursor != nil {
- t := newTraceFromFrames(*cursor, nil)
- return t.Loc()
- }
- return ""
-}
-
-var frameIgnorePattern = regexp.MustCompile(`github\.com/sirupsen/logrus`)
-
-// findFrames positions the stack pointer to the first
-// function that does not match the frameIngorePattern
-// and returns the rest of the stack frames
-func findFrame() *frameCursor {
- var buf [32]uintptr
- // Skip enough frames to start at user code.
- // This number is a mere hint to the following loop
- // to start as close to user code as possible and getting it right is not mandatory.
- // The skip count might need to get updated if the call to findFrame is
- // moved up/down the call stack
- n := runtime.Callers(4, buf[:])
- pcs := buf[:n]
- frames := runtime.CallersFrames(pcs)
- for i := 0; i < n; i++ {
- frame, _ := frames.Next()
- if !frameIgnorePattern.MatchString(frame.Function) {
- return &frameCursor{
- current: &frame,
- rest: frames,
- n: n,
- }
- }
- }
- return nil
-}
-
-func newTraceFromFrames(cursor frameCursor, err error) *trace.TraceErr {
- traces := make(trace.Traces, 0, cursor.n)
- if cursor.current != nil {
- traces = append(traces, frameToTrace(*cursor.current))
- }
- for {
- frame, more := cursor.rest.Next()
- traces = append(traces, frameToTrace(frame))
- if !more {
- break
- }
- }
- return &trace.TraceErr{
- Err: err,
- Traces: traces,
- }
-}
-
-func frameToTrace(frame runtime.Frame) trace.Trace {
- return trace.Trace{
- Func: frame.Function,
- Path: frame.File,
- Line: frame.Line,
- }
-}
diff --git a/lib/utils/log/slog.go b/lib/utils/log/slog.go
index 46f0e13627b3e..bfb34f4a94114 100644
--- a/lib/utils/log/slog.go
+++ b/lib/utils/log/slog.go
@@ -27,7 +27,6 @@ import (
"unicode"
"github.com/gravitational/trace"
- "github.com/sirupsen/logrus"
oteltrace "go.opentelemetry.io/otel/trace"
)
@@ -68,25 +67,6 @@ var SupportedLevelsText = []string{
slog.LevelError.String(),
}
-// SlogLevelToLogrusLevel converts a [slog.Level] to its equivalent
-// [logrus.Level].
-func SlogLevelToLogrusLevel(level slog.Level) logrus.Level {
- switch level {
- case TraceLevel:
- return logrus.TraceLevel
- case slog.LevelDebug:
- return logrus.DebugLevel
- case slog.LevelInfo:
- return logrus.InfoLevel
- case slog.LevelWarn:
- return logrus.WarnLevel
- case slog.LevelError:
- return logrus.ErrorLevel
- default:
- return logrus.FatalLevel
- }
-}
-
// DiscardHandler is a [slog.Handler] that discards all messages. It
// is more efficient than a [slog.Handler] which outputs to [io.Discard] since
// it performs zero formatting.
diff --git a/lib/utils/log/slog_text_handler.go b/lib/utils/log/slog_text_handler.go
index 7f93a388977bb..612615ba8582d 100644
--- a/lib/utils/log/slog_text_handler.go
+++ b/lib/utils/log/slog_text_handler.go
@@ -150,45 +150,12 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error {
// Processing fields in this manner allows users to
// configure the level and component position in the output.
- // This matches the behavior of the original logrus. All other
+ // This matches the behavior of the original logrus formatter. All other
// fields location in the output message are static.
for _, field := range s.cfg.ConfiguredFields {
switch field {
case LevelField:
- var color int
- var level string
- switch r.Level {
- case TraceLevel:
- level = "TRACE"
- color = gray
- case slog.LevelDebug:
- level = "DEBUG"
- color = gray
- case slog.LevelInfo:
- level = "INFO"
- color = blue
- case slog.LevelWarn:
- level = "WARN"
- color = yellow
- case slog.LevelError:
- level = "ERROR"
- color = red
- case slog.LevelError + 1:
- level = "FATAL"
- color = red
- default:
- color = blue
- level = r.Level.String()
- }
-
- if !s.cfg.EnableColors {
- color = noColor
- }
-
- level = padMax(level, defaultLevelPadding)
- if color != noColor {
- level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level)
- }
+ level := formatLevel(r.Level, s.cfg.EnableColors)
if rep == nil {
state.appendKey(slog.LevelKey)
@@ -211,12 +178,8 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error {
if attr.Key != teleport.ComponentKey {
return true
}
- component = fmt.Sprintf("[%v]", attr.Value)
- component = strings.ToUpper(padMax(component, s.cfg.Padding))
- if component[len(component)-1] != ' ' {
- component = component[:len(component)-1] + "]"
- }
+ component = formatComponent(attr.Value, s.cfg.Padding)
return false
})
@@ -271,6 +234,55 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error {
return err
}
+func formatLevel(value slog.Level, enableColors bool) string {
+ var color int
+ var level string
+ switch value {
+ case TraceLevel:
+ level = "TRACE"
+ color = gray
+ case slog.LevelDebug:
+ level = "DEBUG"
+ color = gray
+ case slog.LevelInfo:
+ level = "INFO"
+ color = blue
+ case slog.LevelWarn:
+ level = "WARN"
+ color = yellow
+ case slog.LevelError:
+ level = "ERROR"
+ color = red
+ case slog.LevelError + 1:
+ level = "FATAL"
+ color = red
+ default:
+ color = blue
+ level = value.String()
+ }
+
+ if !enableColors {
+ color = noColor
+ }
+
+ level = padMax(level, defaultLevelPadding)
+ if color != noColor {
+ level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level)
+ }
+
+ return level
+}
+
+func formatComponent(value slog.Value, padding int) string {
+ component := fmt.Sprintf("[%v]", value)
+ component = strings.ToUpper(padMax(component, padding))
+ if component[len(component)-1] != ' ' {
+ component = component[:len(component)-1] + "]"
+ }
+
+ return component
+}
+
func (s *SlogTextHandler) clone() *SlogTextHandler {
// We can't use assignment because we can't copy the mutex.
return &SlogTextHandler{
diff --git a/lib/utils/log/writer.go b/lib/utils/log/writer.go
deleted file mode 100644
index 77cf3037a8b66..0000000000000
--- a/lib/utils/log/writer.go
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package log
-
-import (
- "io"
- "sync"
-)
-
-// SharedWriter is an [io.Writer] implementation that protects
-// writes with a mutex. This allows a single [io.Writer] to be shared
-// by both logrus and slog without their output clobbering each other.
-type SharedWriter struct {
- mu sync.Mutex
- io.Writer
-}
-
-func (s *SharedWriter) Write(p []byte) (int, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- return s.Writer.Write(p)
-}
-
-// NewSharedWriter wraps the provided [io.Writer] in a writer that
-// is thread safe.
-func NewSharedWriter(w io.Writer) *SharedWriter {
- return &SharedWriter{Writer: w}
-}
diff --git a/rfd/0144-client-tools-updates.md b/rfd/0144-client-tools-updates.md
new file mode 100644
index 0000000000000..34ba8062971c8
--- /dev/null
+++ b/rfd/0144-client-tools-updates.md
@@ -0,0 +1,309 @@
+---
+authors: Russell Jones (rjones@goteleport.com) and Bernard Kim (bernard@goteleport.com)
+state: draft
+---
+
+# RFD 0144 - Client Tools Updates
+
+## Required Approvers
+
+* Engineering: @sclevine && @bernardjkim && @r0mant
+* Product: @klizhentas || @xinding33
+* Security: @reedloden
+
+## What/Why
+
+This RFD describes how client tools like `tsh` and `tctl` can be kept up to
+date, either using automatic updates or self-managed updates.
+
+Keeping client tools updated helps with security (fixes for known security
+vulnerabilities are pushed to endpoints), bugs (fixes for resolved issues are
+pushed to endpoints), and compatibility (users no longer have to learn and
+understand [Teleport component
+compatibility](https://goteleport.com/docs/upgrading/overview/#component-compatibility)
+rules).
+
+## Details
+
+### Summary
+
+Client tools like `tsh` and `tctl` will automatically download and install the
+required version for the Teleport cluster.
+
+Enrollment in automatic updates for client tools will be controlled at the
+cluster level. By default all Cloud clusters will be opted into automatic
+updates for client tools. Cluster administrators using MDM software like Jamf
+will be able to opt out and manually manage updates.
+
+Self-hosted clusters will be opted out, but have the option to use the same
+automatic update mechanism.
+
+Inspiration drawn from https://go.dev/doc/toolchain.
+
+### Implementation
+
+#### Client tools
+
+##### Automatic updates
+
+When `tsh login` is executed, client tools will check `/v1/webapi/find` to
+determine if automatic updates are enabled. If the cluster's required version
+differs from the current binary, client tools will download and re-execute
+using the version required by the cluster. All other `tsh` subcommands (like
+`tsh ssh ...`) will always use the downloaded version.
+
+The original client tools binaries won't be overwritten. Instead, an additional
+binary will be downloaded and stored in `~/.tsh/bin` with `0755` permissions.
+
+To validate the binaries have not been corrupted during download, a hash of the
+archive will be checked against the expected value. The expected hash value
+comes from the archive download path with `.sha256` appended.
+
+To enable concurrent operation of client tools, a locking mechanism utilizing
+[syscall.Flock](https://pkg.go.dev/syscall#Flock) (for Linux and macOS) and
+[LockFileEx](https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-lockfileex)
+(for Windows) will be used.
+
+```
+$ tree ~/.tsh
+~/.tsh
+├── bin
+│ ├── tctl
+│ └── tsh
+├── current-profile
+├── keys
+│ └── proxy.example.com
+│ ├── cas
+│ │ └── example.com.pem
+│ ├── certs.pem
+│ ├── foo
+│ ├── foo-ssh
+│ │ └── example.com-cert.pub
+│ ├── foo-x509.pem
+│ └── foo.pub
+├── known_hosts
+└── proxy.example.com.yaml
+```
+
+Users can cancel client tools updates using `Ctrl-C`. This may be needed if the
+user is on a low bandwidth connection (LTE or public Wi-Fi), if the Teleport
+download server is inaccessible, or the user urgently needs to access the
+cluster and can not wait for the update to occur.
+
+```
+$ tsh login --proxy=proxy.example.com
+Client tools are out of date, updating to vX.Y.Z.
+Update progress: [▒▒▒▒▒▒ ] (Ctrl-C to cancel update)
+
+[...]
+```
+
+All archive downloads are targeted to the `cdn.teleport.dev` endpoint and depend
+on the operating system, platform, and edition. Where edition must be identified
+by the original client tools binary, URL pattern:
+`https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz`
+
+An environment variable `TELEPORT_TOOLS_VERSION` will be introduced that can be
+`X.Y.Z` (use specific semver version) or `off` (do not update). This
+environment variable can be used as an emergency workaround for a known issue,
+pinning to a specific version in CI/CD, or for debugging.
+
+During re-execution, the child process will inherit all environment variables and
+flags. `TELEPORT_TOOLS_VERSION=off` will be added during re-execution to
+prevent infinite loops.
+
+When `tctl` is used to connect to Auth Service running on the same host over
+`localhost`, `tctl` assumes a special administrator role that can perform all
+operations on a cluster. In this situation the expectation is for the version
+of `tctl` and `teleport` to match so automatic updates will not be used.
+
+> [!NOTE]
+> If a user connects to multiple root clusters, each running a different
+> version of Teleport, client tools will attempt to download the differing
+> version of Teleport each time the user performs a `tsh login`.
+>
+> In practice, the number of users impacted by this would be small. Customer
+> Cloud tenants would be on the same version and this feature is turned off by
+> default for self-hosted clusters.
+>
+> However, for those people in this situation, the recommendation would be to
+> use self-managed updates.
+
+##### Errors and warnings
+
+If the cluster administrator has chosen not to enroll client tools in automatic
+updates and does not self-manage client tools updates as outlined in
+[Self-managed client tools updates](#self-managed-client-tools-updates), a
+series of warnings and errors with increasing urgency will be shown to the
+user.
+
+If the version of client tools is within the same major version as advertised
+by the cluster, a warning will be shown to urge the user to enroll in automatic
+updates. Warnings will not prevent the user from using client tools that are
+slightly out of date.
+
+```
+$ tsh login --proxy=proxy.example.com
+Warning: Client tools are out of date, update to vX.Y.Z.
+
+Update Teleport to vX.Y.Z from https://goteleport.com/download or your system
+package manager.
+
+Enroll in automatic updates to keep client tools like tsh and tctl
+automatically updated. https://goteleport.com/docs/upgrading/automatic-updates
+
+[...]
+```
+
+If the version of client tools is 1 major version below the version advertised
+by the cluster, a warning will be shown that indicates some functionality may
+not work.
+
+```
+$ tsh login --proxy=proxy.example.com
+WARNING: Client tools are 1 major version out of date, update to vX.Y.Z.
+
+Some functionality may not work. Update Teleport to vX.Y.Z from
+https://goteleport.com/download or your system package manager.
+
+Enroll in automatic updates to keep client tools like tsh and tctl
+automatically updated. https://goteleport.com/docs/upgrading/automatic-updates
+```
+
+If the version of client tools is 2 (or more) versions lower than the version
+advertised by the cluster or 1 (or more) version greater than the version
+advertised by the cluster, an error will be shown and will require the user to
+use the `--skip-version-check` flag.
+
+```
+$ tsh login --proxy=proxy.example.com
+ERROR: Client tools are N major versions out of date, update to vX.Y.Z.
+
+Your cluster requires {tsh,tctl} vX.Y.Z. Update Teleport from
+https://goteleport.com/download or your system package manager.
+
+Enroll in automatic updates to keep client tools like tsh and tctl
+automatically updated. https://goteleport.com/docs/upgrading/automatic-updates
+
+Use the "--skip-version-check" flag to bypass this check and attempt to connect
+to this cluster.
+```
+
+#### Self-managed client tools updates
+
+Cluster administrators that want to self-manage client tools updates will be
+able to get changes to client tools versions which can then be
+used to trigger other integrations (using MDM software like Jamf) to update the
+installed version of client tools on endpoints.
+
+By defining the `proxy` flag, we can use the get command without logging in.
+
+```
+$ tctl autoupdate client-tools status --proxy proxy.example.com --format json
+{
+ "mode": "enabled",
+ "target_version": "X.Y.Z"
+}
+```
+
+##### Cluster configuration
+
+Enrollment of clients in automatic updates will be enforced at the cluster
+level.
+
+The `autoupdate_config` resource will be updated to allow cluster
+administrators to turn client tools automatic updates `on` or `off`.
+An `autoupdate_version` resource will be added to allow cluster administrators
+to manage the version of tools pushed to clients.
+
+> [!NOTE]
+> Client tools configuration is broken into two resources to [prevent
+> updates](https://github.com/gravitational/teleport/blob/master/lib/modules/modules.go#L332-L355)
+> to `autoupdate_version` on Cloud.
+>
+> While Cloud customers will be able to use `autoupdate_config` to
+> turn client tools automatic updates `off` and self-manage updates, they will
+> not be able to control the version of client tools in `autoupdate_version`.
+> That will continue to be managed by the Teleport Cloud team.
+
+Both resources can either be updated directly or by using `tctl` helper
+functions.
+
+```yaml
+kind: autoupdate_config
+spec:
+ tools:
+ # tools mode allows enabling or disabling client tools updates at the
+ # cluster level. Disable client tools automatic updates only if self-managed
+ # updates are in place.
+ mode: enabled|disabled
+```
+```
+$ tctl autoupdate client-tools enable
+client tools auto update mode has been changed
+
+$ tctl autoupdate client-tools disable
+client tools auto update mode has been changed
+```
+
+By default, all Cloud clusters will be opted into `tools.mode: enabled`. All
+self-hosted clusters will be opted into `tools.mode: disabled`.
+
+```yaml
+kind: autoupdate_version
+spec:
+ tools:
+ # target_version is the semver version of client tools the cluster will
+ # advertise.
+ target_version: X.Y.Z
+```
+```
+$ tctl autoupdate client-tools target X.Y.Z
+client tools auto update target version has been set
+
+$ tctl autoupdate client-tools target --clear
+client tools auto update target version has been cleared
+```
+
+For Cloud clusters, `target_version` will always be `X.Y.Z`, with the version
+controlled by the Cloud team.
+
+The above configuration will then be available from the unauthenticated
+proxy discovery endpoint `/v1/webapi/find` which clients will consult.
+Resources that store information about autoupdate and tools version are cached on
+the proxy side to minimize requests to the auth service. In case of an unhealthy
+cache state, the last known version of the resources should be used for the response.
+
+```
+$ curl https://proxy.example.com/v1/webapi/find | jq .auto_update
+{
+ "tools_auto_update": true,
+ "tools_version": "X.Y.Z",
+}
+```
+
+### Costs
+
+Some additional costs will be incurred as Teleport downloads will increase in
+frequency.
+
+### Out of scope
+
+How Cloud will push changes to `autoupdate_version` is out of scope for this
+RFD and will be handled by a separate Cloud specific RFD.
+
+Automatic updates for Teleport Connect are out of scope for this RFD as it uses
+a different install/update mechanism. For now it will call `tsh` with
+`TELEPORT_TOOLS_VERSION=off` until automatic updates support can be added to
+Connect.
+
+### Security
+
+The initial version of automatic updates will rely on TLS to establish
+connection authenticity to the Teleport download server. The authenticity of
+assets served from the download server is out of scope for this RFD. Cluster
+administrators concerned with the authenticity of assets served from the
+download server can use self-managed updates with system package managers which
+are signed.
+
+Phase 2 will use The Upgrade Framework (TUF) to implement secure updates.
diff --git a/rfd/cspell.json b/rfd/cspell.json
index 9982219bada5e..ee16b81f55872 100644
--- a/rfd/cspell.json
+++ b/rfd/cspell.json
@@ -201,6 +201,7 @@
"Statfs",
"Subconditions",
"Submatch",
+ "Sudia",
"Sycqsbqf",
"TBLPROPERTIES",
"TCSD",
@@ -211,6 +212,7 @@
"TPMs",
"Tablename",
"Teleconsole",
+ "Teleporter",
"Teleporting",
"Tiago",
"Tkachenko",
@@ -297,6 +299,7 @@
"behaviour",
"behaviours",
"benchtime",
+ "bernardjkim",
"bizz",
"bjoerger",
"blocklists",
@@ -667,6 +670,7 @@
"runtimes",
"russjones",
"ryanclark",
+ "sclevine",
"secretless",
"selfsubjectaccessreviews",
"selfsubjectrulesreviews",
@@ -731,6 +735,7 @@
"sudoersfile",
"supercede",
"syft",
+ "syscall",
"tablewriter",
"tailscale",
"targetting",
diff --git a/tool/tctl/common/collection.go b/tool/tctl/common/collection.go
index c1ea21addc2b6..4bf5d1629d0c9 100644
--- a/tool/tctl/common/collection.go
+++ b/tool/tctl/common/collection.go
@@ -1792,7 +1792,7 @@ type workloadIdentityCollection struct {
func (c *workloadIdentityCollection) resources() []types.Resource {
r := make([]types.Resource, 0, len(c.items))
for _, resource := range c.items {
- r = append(r, types.Resource153ToLegacy(resource))
+ r = append(r, types.ProtoResource153ToLegacy(resource))
}
return r
}
diff --git a/tool/tsh/common/git_config.go b/tool/tsh/common/git_config.go
index 89771735b30b3..6b703af251cee 100644
--- a/tool/tsh/common/git_config.go
+++ b/tool/tsh/common/git_config.go
@@ -124,11 +124,11 @@ func (c *gitConfigCommand) doUpdate(cf *CLIConf) error {
for _, url := range strings.Split(urls, "\n") {
u, err := parseGitSSHURL(url)
if err != nil {
- logger.DebugContext(cf.Context, "Skippig URL", "error", err, "url", url)
+ logger.DebugContext(cf.Context, "Skipping URL", "error", err, "url", url)
continue
}
if !u.isGitHub() {
- logger.DebugContext(cf.Context, "Skippig non-GitHub host", "host", u.Host)
+ logger.DebugContext(cf.Context, "Skipping non-GitHub host", "host", u.Host)
continue
}
diff --git a/web/packages/teleport/src/Apps/AddApp/AddApp.story.tsx b/web/packages/teleport/src/Apps/AddApp/AddApp.story.tsx
index db9ba0c4007ba..4ae3007934307 100644
--- a/web/packages/teleport/src/Apps/AddApp/AddApp.story.tsx
+++ b/web/packages/teleport/src/Apps/AddApp/AddApp.story.tsx
@@ -16,18 +16,50 @@
* along with this program. If not, see .
*/
+import { useState } from 'react';
+
+import { JoinToken } from 'teleport/services/joinToken';
+
import { AddApp } from './AddApp';
export default {
- title: 'Teleport/Apps/Add',
+ title: 'Teleport/Discover/Application/Web',
};
-export const Created = () => (
-
-);
+export const CreatedWithoutLabels = () => {
+ const [token, setToken] = useState();
+
+ return (
+ {
+ setToken(props.token);
+ return Promise.resolve(true);
+ }}
+ />
+ );
+};
+
+export const CreatedWithLabels = () => {
+ const [token, setToken] = useState();
-export const Loaded = () => {
- return ;
+ return (
+ {
+ setToken(props.token);
+ return Promise.resolve(true);
+ }}
+ />
+ );
};
export const Processing = () => (
@@ -72,8 +104,10 @@ const props = {
createJoinToken: () => Promise.resolve(null),
version: '5.0.0-dev',
reset: () => null,
+ labels: [],
+ setLabels: () => null,
attempt: {
- status: '',
+ status: 'success',
statusText: '',
} as any,
token: {
diff --git a/web/packages/teleport/src/Apps/AddApp/AddApp.tsx b/web/packages/teleport/src/Apps/AddApp/AddApp.tsx
index b40735fbce53d..7a82293d33a7a 100644
--- a/web/packages/teleport/src/Apps/AddApp/AddApp.tsx
+++ b/web/packages/teleport/src/Apps/AddApp/AddApp.tsx
@@ -44,6 +44,8 @@ export function AddApp({
setAutomatic,
isAuthTypeLocal,
token,
+ labels,
+ setLabels,
}: State & Props) {
return (
)}
{!automatic && (
diff --git a/web/packages/teleport/src/Apps/AddApp/Automatically.test.tsx b/web/packages/teleport/src/Apps/AddApp/Automatically.test.tsx
index 5761abdbcb42f..ece5ce843aa57 100644
--- a/web/packages/teleport/src/Apps/AddApp/Automatically.test.tsx
+++ b/web/packages/teleport/src/Apps/AddApp/Automatically.test.tsx
@@ -16,8 +16,6 @@
* along with this program. If not, see .
*/
-import { act } from '@testing-library/react';
-
import { fireEvent, render, screen } from 'design/utils/testing';
import { Automatically, createAppBashCommand } from './Automatically';
@@ -33,12 +31,14 @@ test('render command only after form submit', async () => {
roles: [],
content: '',
};
- render(
+ const { rerender } = render(
{}}
onCreate={() => Promise.resolve(true)}
+ labels={[]}
+ setLabels={() => null}
+ token={null}
/>
);
@@ -56,8 +56,21 @@ test('render command only after form submit', async () => {
target: { value: 'https://gravitational.com' },
});
+ rerender(
+ {}}
+ onCreate={() => Promise.resolve(true)}
+ labels={[]}
+ setLabels={() => null}
+ token={token}
+ />
+ );
+
// click button
- act(() => screen.getByRole('button', { name: /Generate Script/i }).click());
+ fireEvent.click(screen.getByRole('button', { name: /Generate Script/i }));
+
+ await screen.findByText(/Regenerate Script/i);
// after form submission should show the command
cmd = createAppBashCommand(token.id, 'app-name', 'https://gravitational.com');
diff --git a/web/packages/teleport/src/Apps/AddApp/Automatically.tsx b/web/packages/teleport/src/Apps/AddApp/Automatically.tsx
index de6669284f1ce..6e49916ef1261 100644
--- a/web/packages/teleport/src/Apps/AddApp/Automatically.tsx
+++ b/web/packages/teleport/src/Apps/AddApp/Automatically.tsx
@@ -20,6 +20,7 @@ import { KeyboardEvent, useEffect, useState } from 'react';
import {
Alert,
+ Box,
ButtonPrimary,
ButtonSecondary,
Flex,
@@ -33,24 +34,27 @@ import { Attempt } from 'shared/hooks/useAttemptNext';
import TextSelectCopy from 'teleport/components/TextSelectCopy';
import cfg from 'teleport/config';
+import { LabelsCreater } from 'teleport/Discover/Shared';
+import { ResourceLabelTooltip } from 'teleport/Discover/Shared/ResourceLabelTooltip';
+import { ResourceLabel } from 'teleport/services/agents';
import { State } from './useAddApp';
export function Automatically(props: Props) {
- const { onClose, attempt, token } = props;
+ const { onClose, attempt, token, labels, setLabels } = props;
const [name, setName] = useState('');
const [uri, setUri] = useState('');
const [cmd, setCmd] = useState('');
useEffect(() => {
- if (name && uri) {
+ if (name && uri && token) {
const cmd = createAppBashCommand(token.id, name, uri);
setCmd(cmd);
}
}, [token]);
- function handleRegenerate(validator: Validator) {
+ function onGenerateScript(validator: Validator) {
if (!validator.validate()) {
return;
}
@@ -58,25 +62,12 @@ export function Automatically(props: Props) {
props.onCreate(name, uri);
}
- function handleGenerate(validator: Validator) {
- if (!validator.validate()) {
- return;
- }
-
- const cmd = createAppBashCommand(token.id, name, uri);
- setCmd(cmd);
- }
-
function handleEnterPress(
e: KeyboardEvent,
validator: Validator
) {
if (e.key === 'Enter') {
- if (cmd) {
- handleRegenerate(validator);
- } else {
- handleGenerate(validator);
- }
+ onGenerateScript(validator);
}
}
@@ -96,6 +87,7 @@ export function Automatically(props: Props) {
mr="3"
onKeyPress={e => handleEnterPress(e, validator)}
onChange={e => setName(e.target.value.toLowerCase())}
+ disabled={attempt.status === 'processing'}
/>
handleEnterPress(e, validator)}
onChange={e => setUri(e.target.value)}
+ disabled={attempt.status === 'processing'}
/>
+
+
+ Add Labels (Optional)
+
+
+
+
{!cmd && (
Teleport can automatically set up application access. Provide
@@ -136,24 +145,13 @@ export function Automatically(props: Props) {
)}
- {!cmd && (
- handleGenerate(validator)}
- >
- Generate Script
-
- )}
- {cmd && (
- handleRegenerate(validator)}
- >
- Regenerate
-
- )}
+ onGenerateScript(validator)}
+ >
+ {cmd ? 'Regenerate Script' : 'Generate Script'}
+ ;
token: State['token'];
attempt: Attempt;
+ labels: ResourceLabel[];
+ setLabels(r: ResourceLabel[]): void;
};
diff --git a/web/packages/teleport/src/Apps/AddApp/useAddApp.ts b/web/packages/teleport/src/Apps/AddApp/useAddApp.ts
index be04b6cba17fd..cad6afd65c95c 100644
--- a/web/packages/teleport/src/Apps/AddApp/useAddApp.ts
+++ b/web/packages/teleport/src/Apps/AddApp/useAddApp.ts
@@ -20,6 +20,7 @@ import { useEffect, useState } from 'react';
import useAttempt from 'shared/hooks/useAttemptNext';
+import { ResourceLabel } from 'teleport/services/agents';
import type { JoinToken } from 'teleport/services/joinToken';
import TeleportContext from 'teleport/teleportContext';
@@ -31,14 +32,27 @@ export default function useAddApp(ctx: TeleportContext) {
const isEnterprise = ctx.isEnterprise;
const [automatic, setAutomatic] = useState(isEnterprise);
const [token, setToken] = useState();
+ const [labels, setLabels] = useState([]);
useEffect(() => {
- createToken();
- }, []);
+ // We don't want to create token on first render
+ // which defaults to the automatic tab because
+ // user may want to add labels.
+ if (!automatic) {
+ setLabels([]);
+ // When switching to manual tab, token can be re-used
+ // if token was already generated from automatic tab.
+ if (!token) {
+ createToken();
+ }
+ }
+ }, [automatic]);
function createToken() {
return run(() =>
- ctx.joinTokenService.fetchJoinToken({ roles: ['App'] }).then(setToken)
+ ctx.joinTokenService
+ .fetchJoinToken({ roles: ['App'], suggestedLabels: labels })
+ .then(setToken)
);
}
@@ -52,6 +66,8 @@ export default function useAddApp(ctx: TeleportContext) {
isAuthTypeLocal,
isEnterprise,
token,
+ labels,
+ setLabels,
};
}
diff --git a/web/packages/teleport/src/Discover/Shared/ResourceLabelTooltip/ResourceLabelTooltip.tsx b/web/packages/teleport/src/Discover/Shared/ResourceLabelTooltip/ResourceLabelTooltip.tsx
index 4feb605ae4692..f0d5ddc8abf5e 100644
--- a/web/packages/teleport/src/Discover/Shared/ResourceLabelTooltip/ResourceLabelTooltip.tsx
+++ b/web/packages/teleport/src/Discover/Shared/ResourceLabelTooltip/ResourceLabelTooltip.tsx
@@ -37,12 +37,36 @@ export function ResourceLabelTooltip({
resourceKind,
toolTipPosition,
}: {
- resourceKind: 'server' | 'eks' | 'rds' | 'kube' | 'db';
+ resourceKind: 'server' | 'eks' | 'rds' | 'kube' | 'db' | 'app';
toolTipPosition?: Position;
}) {
let tip;
switch (resourceKind) {
+ case 'app': {
+ tip = (
+ <>
+ Labels allow you to do the following:
+
+
+ Filter applications by labels when using tsh, tctl, or the web UI.
+
+
+ Restrict access to this application with{' '}
+
+ Teleport RBAC
+
+ . Only roles with app_labels that match
+ these labels will be allowed to access this application.
+