diff --git a/api/grpc/token/v1/token.pb.go b/api/grpc/token/v1/token.pb.go index 937ee5fee4..683a114c03 100644 --- a/api/grpc/token/v1/token.pb.go +++ b/api/grpc/token/v1/token.pb.go @@ -144,6 +144,50 @@ func (x *RefreshReq) GetVerified() bool { return false } +type RevokeReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeReq) Reset() { + *x = RevokeReq{} + mi := &file_token_v1_token_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeReq) ProtoMessage() {} + +func (x *RevokeReq) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead. +func (*RevokeReq) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{2} +} + +func (x *RevokeReq) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -158,7 +202,7 @@ type Token struct { func (x *Token) Reset() { *x = Token{} - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -170,7 +214,7 @@ func (x *Token) String() string { func (*Token) ProtoMessage() {} func (x *Token) ProtoReflect() protoreflect.Message { - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -183,7 +227,7 @@ func (x *Token) ProtoReflect() protoreflect.Message { // Deprecated: Use Token.ProtoReflect.Descriptor instead. func (*Token) Descriptor() ([]byte, []int) { - return file_token_v1_token_proto_rawDescGZIP(), []int{2} + return file_token_v1_token_proto_rawDescGZIP(), []int{3} } func (x *Token) GetAccessToken() string { @@ -207,6 +251,42 @@ func (x *Token) GetAccessType() string { return "" } +type RevokeRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeRes) Reset() { + *x = RevokeRes{} + mi := &file_token_v1_token_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeRes) ProtoMessage() {} + +func (x *RevokeRes) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeRes.ProtoReflect.Descriptor instead. +func (*RevokeRes) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{4} +} + var File_token_v1_token_proto protoreflect.FileDescriptor const file_token_v1_token_proto_rawDesc = "" + @@ -220,16 +300,20 @@ const file_token_v1_token_proto_rawDesc = "" + "\n" + "RefreshReq\x12#\n" + "\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" + - "\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" + + "\bverified\x18\x02 \x01(\bR\bverified\"!\n" + + "\tRevokeReq\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\"\x87\x01\n" + "\x05Token\x12!\n" + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" + "\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" + "\vaccess_type\x18\x03 \x01(\tR\n" + "accessTypeB\x10\n" + - "\x0e_refresh_token2r\n" + + "\x0e_refresh_token\"\v\n" + + "\tRevokeRes2\xa8\x01\n" + "\fTokenService\x12.\n" + "\x05Issue\x12\x12.token.v1.IssueReq\x1a\x0f.token.v1.Token\"\x00\x122\n" + - "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" + "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00\x124\n" + + "\x06Revoke\x12\x13.token.v1.RevokeReq\x1a\x13.token.v1.RevokeRes\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" var ( file_token_v1_token_proto_rawDescOnce sync.Once @@ -243,19 +327,23 @@ func file_token_v1_token_proto_rawDescGZIP() []byte { return file_token_v1_token_proto_rawDescData } -var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_token_v1_token_proto_goTypes = []any{ (*IssueReq)(nil), // 0: token.v1.IssueReq (*RefreshReq)(nil), // 1: token.v1.RefreshReq - (*Token)(nil), // 2: token.v1.Token + (*RevokeReq)(nil), // 2: token.v1.RevokeReq + (*Token)(nil), // 3: token.v1.Token + (*RevokeRes)(nil), // 4: token.v1.RevokeRes } var file_token_v1_token_proto_depIdxs = []int32{ 0, // 0: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq 1, // 1: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq - 2, // 2: token.v1.TokenService.Issue:output_type -> token.v1.Token - 2, // 3: token.v1.TokenService.Refresh:output_type -> token.v1.Token - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type + 2, // 2: token.v1.TokenService.Revoke:input_type -> token.v1.RevokeReq + 3, // 3: token.v1.TokenService.Issue:output_type -> token.v1.Token + 3, // 4: token.v1.TokenService.Refresh:output_type -> token.v1.Token + 4, // 5: token.v1.TokenService.Revoke:output_type -> token.v1.RevokeRes + 3, // [3:6] is the sub-list for method output_type + 0, // [0:3] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -266,14 +354,14 @@ func file_token_v1_token_proto_init() { if File_token_v1_token_proto != nil { return } - file_token_v1_token_proto_msgTypes[2].OneofWrappers = []any{} + file_token_v1_token_proto_msgTypes[3].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_v1_token_proto_rawDesc), len(file_token_v1_token_proto_rawDesc)), NumEnums: 0, - NumMessages: 3, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/api/grpc/token/v1/token_grpc.pb.go b/api/grpc/token/v1/token_grpc.pb.go index 70ac6a7609..98a7a00c4a 100644 --- a/api/grpc/token/v1/token_grpc.pb.go +++ b/api/grpc/token/v1/token_grpc.pb.go @@ -24,6 +24,7 @@ const _ = grpc.SupportPackageIsVersion9 const ( TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue" TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh" + TokenService_Revoke_FullMethodName = "/token.v1.TokenService/Revoke" ) // TokenServiceClient is the client API for TokenService service. @@ -32,6 +33,7 @@ const ( type TokenServiceClient interface { Issue(ctx context.Context, in *IssueReq, opts ...grpc.CallOption) (*Token, error) Refresh(ctx context.Context, in *RefreshReq, opts ...grpc.CallOption) (*Token, error) + Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) } type tokenServiceClient struct { @@ -62,12 +64,23 @@ func (c *tokenServiceClient) Refresh(ctx context.Context, in *RefreshReq, opts . return out, nil } +func (c *tokenServiceClient) Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RevokeRes) + err := c.cc.Invoke(ctx, TokenService_Revoke_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // TokenServiceServer is the server API for TokenService service. // All implementations must embed UnimplementedTokenServiceServer // for forward compatibility. type TokenServiceServer interface { Issue(context.Context, *IssueReq) (*Token, error) Refresh(context.Context, *RefreshReq) (*Token, error) + Revoke(context.Context, *RevokeReq) (*RevokeRes, error) mustEmbedUnimplementedTokenServiceServer() } @@ -84,6 +97,9 @@ func (UnimplementedTokenServiceServer) Issue(context.Context, *IssueReq) (*Token func (UnimplementedTokenServiceServer) Refresh(context.Context, *RefreshReq) (*Token, error) { return nil, status.Errorf(codes.Unimplemented, "method Refresh not implemented") } +func (UnimplementedTokenServiceServer) Revoke(context.Context, *RevokeReq) (*RevokeRes, error) { + return nil, status.Errorf(codes.Unimplemented, "method Revoke not implemented") +} func (UnimplementedTokenServiceServer) mustEmbedUnimplementedTokenServiceServer() {} func (UnimplementedTokenServiceServer) testEmbeddedByValue() {} @@ -141,6 +157,24 @@ func _TokenService_Refresh_Handler(srv interface{}, ctx context.Context, dec fun return interceptor(ctx, in, info, handler) } +func _TokenService_Revoke_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TokenServiceServer).Revoke(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: TokenService_Revoke_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TokenServiceServer).Revoke(ctx, req.(*RevokeReq)) + } + return interceptor(ctx, in, info, handler) +} + // TokenService_ServiceDesc is the grpc.ServiceDesc for TokenService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -156,6 +190,10 @@ var TokenService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Refresh", Handler: _TokenService_Refresh_Handler, }, + { + MethodName: "Revoke", + Handler: _TokenService_Revoke_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "token/v1/token.proto", diff --git a/auth/api/grpc/token/client.go b/auth/api/grpc/token/client.go index 25c3bf62fe..96cbdb2d10 100644 --- a/auth/api/grpc/token/client.go +++ b/auth/api/grpc/token/client.go @@ -20,6 +20,7 @@ const tokenSvcName = "token.v1.TokenService" type tokenGrpcClient struct { issue endpoint.Endpoint refresh endpoint.Endpoint + revoke endpoint.Endpoint timeout time.Duration } @@ -44,6 +45,14 @@ func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.To decodeRefreshResponse, grpcTokenV1.Token{}, ).Endpoint(), + revoke: kitgrpc.NewClient( + conn, + tokenSvcName, + "Revoke", + encodeRevokeRequest, + decodeRevokeResponse, + grpcTokenV1.RevokeRes{}, + ).Endpoint(), timeout: timeout, } } @@ -97,3 +106,23 @@ func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) { func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) { return grpcRes, nil } + +func (client tokenGrpcClient) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq, _ ...grpc.CallOption) (*grpcTokenV1.RevokeRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.revoke(ctx, revokeReq{token: req.GetToken()}) + if err != nil { + return &grpcTokenV1.RevokeRes{}, grpcapi.DecodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func encodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(revokeReq) + return &grpcTokenV1.RevokeReq{Token: req.token}, nil +} + +func decodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return grpcRes, nil +} diff --git a/auth/api/grpc/token/endpoint.go b/auth/api/grpc/token/endpoint.go index b03e42ae53..dc8cca3611 100644 --- a/auth/api/grpc/token/endpoint.go +++ b/auth/api/grpc/token/endpoint.go @@ -56,3 +56,18 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint { return ret, nil } } + +func revokeEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(revokeReq) + if err := req.validate(); err != nil { + return nil, err + } + err := svc.RevokeToken(ctx, req.token) + if err != nil { + return nil, err + } + + return nil, nil + } +} diff --git a/auth/api/grpc/token/endpoint_test.go b/auth/api/grpc/token/endpoint_test.go index be1820e42c..08142d880a 100644 --- a/auth/api/grpc/token/endpoint_test.go +++ b/auth/api/grpc/token/endpoint_test.go @@ -24,24 +24,9 @@ import ( ) const ( - port = 8082 - secret = "secret" - email = "test@example.com" - id = "testID" - clientsType = "clients" - usersType = "users" - description = "Description" - groupName = "smqx" - adminPermission = "admin" - - authoritiesObj = "authorities" - memberRelation = "member" - loginDuration = 30 * time.Minute - refreshDuration = 24 * time.Hour - invalidDuration = 7 * 24 * time.Hour - validToken = "valid" - inValidToken = "invalid" - validPolicy = "valid" + port = 8082 + validToken = "valid" + inValidToken = "invalid" ) var ( @@ -63,9 +48,9 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server { func TestIssue(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -127,9 +112,9 @@ func TestIssue(t *testing.T) { func TestRefresh(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -167,3 +152,44 @@ func TestRefresh(t *testing.T) { svcCall.Unset() } } + +func TestRevoke(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() + + cases := []struct { + desc string + token string + err error + }{ + { + desc: "revoke token with valid token", + token: validToken, + err: nil, + }, + { + desc: "revoke token with invalid token", + token: inValidToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke token with empty token", + token: "", + err: apiutil.ErrMissingSecret, + }, + { + desc: "revoke already revoked token", + token: validToken, + err: svcerr.ErrConflict, + }, + } + + for _, tc := range cases { + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) + _, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{Token: tc.token}) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + } +} diff --git a/auth/api/grpc/token/requests.go b/auth/api/grpc/token/requests.go index a5ab3c0949..a87d2b83e4 100644 --- a/auth/api/grpc/token/requests.go +++ b/auth/api/grpc/token/requests.go @@ -38,3 +38,15 @@ func (req refreshReq) validate() error { return nil } + +type revokeReq struct { + token string +} + +func (req revokeReq) validate() error { + if req.token == "" { + return apiutil.ErrMissingSecret + } + + return nil +} diff --git a/auth/api/grpc/token/server.go b/auth/api/grpc/token/server.go index 319e46e6ec..7e5263994d 100644 --- a/auth/api/grpc/token/server.go +++ b/auth/api/grpc/token/server.go @@ -18,6 +18,7 @@ type tokenGrpcServer struct { grpcTokenV1.UnimplementedTokenServiceServer issue kitgrpc.Handler refresh kitgrpc.Handler + revoke kitgrpc.Handler } // NewAuthServer returns new AuthnServiceServer instance. @@ -33,6 +34,11 @@ func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer { decodeRefreshRequest, encodeIssueResponse, ), + revoke: kitgrpc.NewServer( + (revokeEndpoint(svc)), + decodeRevokeRequest, + encodeRevokeResponse, + ), } } @@ -76,3 +82,20 @@ func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) { AccessType: res.accessType, }, nil } + +func (s *tokenGrpcServer) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq) (*grpcTokenV1.RevokeRes, error) { + _, res, err := s.revoke.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func decodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(*grpcTokenV1.RevokeReq) + return revokeReq{token: req.GetToken()}, nil +} + +func encodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return &grpcTokenV1.RevokeRes{}, nil +} diff --git a/auth/api/http/keys/endpoint.go b/auth/api/http/keys/endpoint.go index d168acd5be..583485509a 100644 --- a/auth/api/http/keys/endpoint.go +++ b/auth/api/http/keys/endpoint.go @@ -85,3 +85,18 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint { return revokeKeyRes{}, nil } } + +func revokeTokenEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeTokenReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokeToken(ctx, req.token); err != nil { + return nil, err + } + + return revokeKeyRes{}, nil + } +} diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index d9e996f1c2..8b77f0fb0c 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -31,7 +31,6 @@ const ( secret = "secret" contentType = "application/json" id = "123e4567-e89b-12d3-a456-000000000001" - email = "user@example.com" loginDuration = 30 * time.Minute refreshDuration = 24 * time.Hour invalidDuration = 7 * 24 * time.Hour @@ -80,7 +79,9 @@ func newService() auth.Service { idProvider := uuid.NewMock() pService := new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) - t := jwt.New([]byte(secret)) + repo := new(mocks.TokensRepository) + tcache := new(mocks.TokensCache) + t := jwt.New([]byte(secret), repo, tcache) return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration) } @@ -352,3 +353,56 @@ func TestRevoke(t *testing.T) { repoCall.Unset() } } + +func TestRevokeToken(t *testing.T) { + svc := new(mocks.Service) + + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + cases := []struct { + desc string + id string + token string + err error + status int + }{ + { + desc: "revoke an existing token", + token: "token", + status: http.StatusNoContent, + }, + { + desc: "revoke a non-existing token", + token: "token", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke invalid token", + token: "wrong", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke empty token", + token: "", + status: http.StatusUnauthorized, + }, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodDelete, + url: fmt.Sprintf("%s/keys/", ts.URL), + token: tc.token, + } + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + } +} diff --git a/auth/api/http/keys/requests.go b/auth/api/http/keys/requests.go index 20568427c8..a25d1af090 100644 --- a/auth/api/http/keys/requests.go +++ b/auth/api/http/keys/requests.go @@ -46,3 +46,15 @@ func (req keyReq) validate() error { } return nil } + +type revokeTokenReq struct { + token string +} + +func (req revokeTokenReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + return nil +} diff --git a/auth/api/http/keys/requests_test.go b/auth/api/http/keys/requests_test.go index cf44337349..4c717596bc 100644 --- a/auth/api/http/keys/requests_test.go +++ b/auth/api/http/keys/requests_test.go @@ -86,3 +86,30 @@ func TestKeyReqValidate(t *testing.T) { assert.Equal(t, tc.err, err) } } + +func TestRevokeTokenReqValidate(t *testing.T) { + cases := []struct { + desc string + req revokeTokenReq + err error + }{ + { + desc: "valid request", + req: revokeTokenReq{ + token: valid, + }, + err: nil, + }, + { + desc: "empty token", + req: revokeTokenReq{ + token: "", + }, + err: apiutil.ErrBearerToken, + }, + } + for _, tc := range cases { + err := tc.req.validate() + assert.Equal(t, tc.err, err) + } +} diff --git a/auth/api/http/keys/transport.go b/auth/api/http/keys/transport.go index 689fc32554..fbd7688c40 100644 --- a/auth/api/http/keys/transport.go +++ b/auth/api/http/keys/transport.go @@ -33,6 +33,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { opts..., ).ServeHTTP) + r.Delete("/", kithttp.NewServer( + revokeTokenEndpoint(svc), + decodeRevokeTokenReq, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/{id}", kithttp.NewServer( (retrieveEndpoint(svc)), decodeKeyReq, @@ -70,3 +77,11 @@ func decodeKeyReq(_ context.Context, r *http.Request) (any, error) { } return req, nil } + +func decodeRevokeTokenReq(_ context.Context, r *http.Request) (interface{}, error) { + req := revokeTokenReq{ + token: apiutil.ExtractBearerToken(r), + } + + return req, nil +} diff --git a/auth/cache/doc.go b/auth/cache/doc.go index 42396c9830..571a973d0f 100644 --- a/auth/cache/doc.go +++ b/auth/cache/doc.go @@ -1,4 +1,6 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 +// Package cache contains the domain concept definitions needed to +// support SuperMQ auth cache service functionality. package cache diff --git a/auth/cache/setup_test.go b/auth/cache/setup_test.go new file mode 100644 index 0000000000..44ae0d232d --- /dev/null +++ b/auth/cache/setup_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "testing" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/redis/go-redis/v9" +) + +var ( + redisURL string + redisClient *redis.Client +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7.2.4-alpine", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + handleInterrupt(pool, container) + + redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp")) + opts, err := redis.ParseURL(redisURL) + if err != nil { + log.Fatalf("Could not parse redis URL: %s", err) + } + + if err := pool.Retry(func() error { + redisClient = redis.NewClient(opts) + + return redisClient.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + code := m.Run() + + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} + +func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + os.Exit(0) + }() +} diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 0000000000..7cd6387272 --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +const defKey = "revoked_tokens" + +var _ auth.TokensCache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewTokensCache returns redis auth cache implementation. +func NewTokensCache(client *redis.Client, duration time.Duration) auth.TokensCache { + return &tokensCache{ + client: client, + keyDuration: duration, + } +} + +func (tc *tokensCache) Save(ctx context.Context, _, value string) error { + if err := tc.client.SAdd(ctx, defKey, value).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (tc *tokensCache) Contains(ctx context.Context, _, value string) bool { + ok, err := tc.client.SIsMember(ctx, defKey, value).Result() + if err != nil { + return false + } + + return ok +} + +func (tc *tokensCache) Remove(ctx context.Context, value string) error { + if err := tc.client.SRem(ctx, defKey, value).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 0000000000..a59b77aff0 --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,174 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/cache" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/stretchr/testify/assert" +) + +func setupRedisTokensClient() auth.TokensCache { + return cache.NewTokensCache(redisClient, 10*time.Minute) +} + +func TestTokenSave(t *testing.T) { + redisClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + key := auth.Key{ + ID: testsutil.GenerateUUID(t), + } + cases := []struct { + desc string + key auth.Key + err error + }{ + { + desc: "Save token", + key: key, + err: nil, + }, + { + desc: "Save already cached policy", + key: key, + err: nil, + }, + { + desc: "Save another policy", + key: auth.Key{ + ID: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "Save policy with long key", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Save(context.Background(), "", tc.key.ID) + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + redisClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + key := auth.Key{ + ID: testsutil.GenerateUUID(t), + } + + err := tokensCache.Save(context.Background(), "", key.ID) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + key auth.Key + ok bool + }{ + { + desc: "Contains existing key", + key: key, + ok: true, + }, + { + desc: "Contains non existing key", + key: auth.Key{ + ID: testsutil.GenerateUUID(t), + }, + }, + { + desc: "Contains key with long id", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + }, + { + desc: "Contains key with empty id", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + redisClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + num := 10 + var ids []string + for range num { + id := testsutil.GenerateUUID(t) + err := tokensCache.Save(context.Background(), "", id) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + ids = append(ids, id) + } + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "Remove an existing token from cache", + id: ids[0], + err: nil, + }, + { + desc: "Remove token with empty id from cache", + err: nil, + }, + { + desc: "Remove non existing id from cache", + id: testsutil.GenerateUUID(t), + err: nil, + }, + { + desc: "Remove token with long id from cache", + id: strings.Repeat("a", 513*1024*1024), + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Remove(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err)) + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.id) + assert.False(t, ok) + } + }) + } +} diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index f221653ab6..eca68c4ac3 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -4,14 +4,17 @@ package jwt_test import ( + "context" "fmt" "testing" "time" "github.com/absmach/supermq/auth" authjwt "github.com/absmach/supermq/auth/jwt" + "github.com/absmach/supermq/auth/mocks" "github.com/absmach/supermq/internal/testsutil" "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -51,7 +54,9 @@ func newToken(issuerName string, key auth.Key) string { } func TestIssue(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokensRepository) + cache := new(mocks.TokensCache) + tokenizer := authjwt.New([]byte(secret), repo, cache) cases := []struct { desc string @@ -128,7 +133,9 @@ func TestIssue(t *testing.T) { } func TestParse(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokensRepository) + cache := new(mocks.TokensCache) + tokenizer := authjwt.New([]byte(secret), repo, cache) token, err := tokenizer.Issue(key()) require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) @@ -144,27 +151,35 @@ func TestParse(t *testing.T) { expToken, err := tokenizer.Issue(expKey) require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) + emptyDomainKey := key() + emptyDomainToken, err := tokenizer.Issue(emptyDomainKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + emptySubjectKey := key() emptySubjectKey.Subject = "" emptySubjectToken, err := tokenizer.Issue(emptySubjectKey) require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) - emptyTypeKey := key() - emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1) - emptyTypeToken, err := tokenizer.Issue(emptyTypeKey) - require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) - emptyKey := key() emptyKey.Subject = "" + emptyToken, err := tokenizer.Issue(emptyKey) require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) inValidToken := newToken("invalid", key()) + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + cases := []struct { - desc string - key auth.Key - token string - err error + desc string + key auth.Key + token string + cacheContains bool + repoContains bool + cacheSave error + err error }{ { desc: "parse valid key", @@ -202,6 +217,12 @@ func TestParse(t *testing.T) { token: newToken(issuerName, key()), err: authjwt.ErrJSONHandle, }, + { + desc: "parse token with empty domain", + key: emptyDomainKey, + token: emptyDomainToken, + err: nil, + }, { desc: "parse token with empty subject", key: emptySubjectKey, @@ -209,21 +230,186 @@ func TestParse(t *testing.T) { err: nil, }, { - desc: "parse token with empty type", - key: emptyTypeKey, - token: emptyTypeToken, - err: errors.ErrAuthentication, + desc: "parse token with empty domain and subject", + key: emptyKey, + token: emptyToken, + err: nil, + }, + { + desc: "parse refresh token", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: false, + err: nil, + }, + { + desc: "parse revoked refresh token in cache", + key: refreshKey, + token: refreshToken, + cacheContains: true, + repoContains: false, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token not in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, }, } for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - key, err := tokenizer.Parse(tc.token) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) - if err == nil { - assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) - } - }) + cacheCall := cache.On("Contains", context.Background(), "", tc.key.ID).Return(tc.cacheContains) + repoCall := repo.On("Contains", context.Background(), tc.key.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheSave) + key, err := tokenizer.Parse(context.Background(), tc.token) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) + } + cacheCall.Unset() + repoCall.Unset() + cacheCall1.Unset() + } +} + +func TestRevoke(t *testing.T) { + repo := new(mocks.TokensRepository) + cache := new(mocks.TokensCache) + tokenizer := authjwt.New([]byte(secret), repo, cache) + + token, err := tokenizer.Issue(key()) + require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) + + apiKey := key() + apiKey.Type = auth.APIKey + apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + apiToken, err := tokenizer.Issue(apiKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + expKey := key() + expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + expToken, err := tokenizer.Issue(expKey) + require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) + + emptyDomainKey := key() + emptyDomainToken, err := tokenizer.Issue(emptyDomainKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptySubjectKey := key() + emptySubjectKey.Subject = "" + emptySubjectToken, err := tokenizer.Issue(emptySubjectKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptyKey := key() + emptyKey.Subject = "" + emptyToken, err := tokenizer.Issue(emptyKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + inValidToken := newToken("invalid", key()) + + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + cases := []struct { + desc string + key auth.Key + token string + repoErr error + cacheErr error + err error + }{ + { + desc: "revoke valid key", + key: key(), + token: token, + err: nil, + }, + { + desc: "revoke invalid key", + key: auth.Key{}, + token: "invalid", + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke expired key", + key: auth.Key{}, + token: expToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke expired API key", + key: apiKey, + token: apiToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke token with invalid issuer", + key: auth.Key{}, + token: inValidToken, + err: errInvalidIssuer, + }, + { + desc: "revoke token with invalid content", + key: auth.Key{}, + token: newToken(issuerName, key()), + err: authjwt.ErrJSONHandle, + }, + { + desc: "revoke token with empty domain", + key: emptyDomainKey, + token: emptyDomainToken, + err: nil, + }, + { + desc: "revoke token with empty subject", + key: emptySubjectKey, + token: emptySubjectToken, + err: nil, + }, + { + desc: "revoke token with empty domain and subject", + key: emptyKey, + token: emptyToken, + err: nil, + }, + { + desc: "revoke refresh token", + key: refreshKey, + token: refreshToken, + err: nil, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: nil, + cacheErr: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := repo.On("Save", context.Background(), tc.key.ID).Return(tc.repoErr) + cacheCall := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheErr) + err := tokenizer.Revoke(context.Background(), tc.token) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + cacheCall.Unset() + repoCall.Unset() } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index 9f0e24838b..86080b18f1 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -31,6 +31,8 @@ var ( ErrValidateJWTToken = errors.New("failed to validate jwt token") // ErrJSONHandle indicates an error in handling JSON. ErrJSONHandle = errors.New("failed to perform operation JSON") + // errRevokedToken indicates that the token is revoked. + errRevokedToken = errors.New("token is revoked") ) const ( @@ -46,14 +48,18 @@ const ( type tokenizer struct { secret []byte + cache auth.TokensCache + repo auth.TokensRepository } var _ auth.Tokenizer = (*tokenizer)(nil) -// NewRepository instantiates an implementation of Token repository. -func New(secret []byte) auth.Tokenizer { +// New instantiates an implementation of Tokenizer service. +func New(secret []byte, repo auth.TokensRepository, cache auth.TokensCache) auth.Tokenizer { return &tokenizer{ secret: secret, + repo: repo, + cache: cache, } } @@ -83,7 +89,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return string(signedTkn), nil } -func (tok *tokenizer) Parse(token string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { tkn, err := tok.validateToken(token) if err != nil { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -94,9 +100,48 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) } + if key.Type == auth.RefreshKey { + switch tok.cache.Contains(ctx, "", key.ID) { + case true: + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + default: + if ok := tok.repo.Contains(ctx, key.ID); ok { + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + } + } + } + return key, nil } +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + tkn, err := tok.validateToken(token) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + key, err := toKey(tkn) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if key.Type == auth.RefreshKey { + if err := tok.repo.Save(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + func (tok *tokenizer) validateToken(token string) (jwt.Token, error) { tkn, err := jwt.Parse( []byte(token), diff --git a/auth/middleware/logging.go b/auth/middleware/logging.go index 5b6ccc5cd0..0e248b672f 100644 --- a/auth/middleware/logging.go +++ b/auth/middleware/logging.go @@ -46,6 +46,22 @@ func (lm *loggingMiddleware) Issue(ctx context.Context, token string, key auth.K return lm.svc.Issue(ctx, token, key) } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, token string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke token failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, token) +} + func (lm *loggingMiddleware) Revoke(ctx context.Context, token, id string) (err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/middleware/metrics.go b/auth/middleware/metrics.go index dbbebefc67..ec7bfbbc38 100644 --- a/auth/middleware/metrics.go +++ b/auth/middleware/metrics.go @@ -40,6 +40,15 @@ func (ms *metricsMiddleware) Issue(ctx context.Context, token string, key auth.K return ms.svc.Issue(ctx, token, key) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, token) +} + func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error { defer func(begin time.Time) { ms.counter.With("method", "revoke_key").Add(1) diff --git a/auth/middleware/tracing.go b/auth/middleware/tracing.go index b20828eefa..7af2391fb5 100644 --- a/auth/middleware/tracing.go +++ b/auth/middleware/tracing.go @@ -36,6 +36,13 @@ func (tm *tracingMiddleware) Issue(ctx context.Context, token string, key auth.K return tm.svc.Issue(ctx, token, key) } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_token") + defer span.End() + + return tm.svc.RevokeToken(ctx, token) +} + func (tm *tracingMiddleware) Revoke(ctx context.Context, token, id string) error { ctx, span := tm.tracer.Start(ctx, "revoke", trace.WithAttributes( attribute.String("id", id), diff --git a/auth/mocks/service.go b/auth/mocks/service.go index c2f688bbe1..8fab6d5195 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -1298,6 +1298,63 @@ func (_c *Service_RevokePATSecret_Call) RunAndReturn(run func(ctx context.Contex return _c } +// RevokeToken provides a mock function for the type Service +func (_mock *Service) RevokeToken(ctx context.Context, token string) error { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeToken' +type Service_RevokeToken_Call struct { + *mock.Call +} + +// RevokeToken is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *Service_Expecter) RevokeToken(ctx interface{}, token interface{}) *Service_RevokeToken_Call { + return &Service_RevokeToken_Call{Call: _e.mock.On("RevokeToken", ctx, token)} +} + +func (_c *Service_RevokeToken_Call) Run(run func(ctx context.Context, token string)) *Service_RevokeToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_RevokeToken_Call) Return(err error) *Service_RevokeToken_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeToken_Call) RunAndReturn(run func(ctx context.Context, token string) error) *Service_RevokeToken_Call { + _c.Call.Return(run) + return _c +} + // UpdatePATDescription provides a mock function for the type Service func (_mock *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { ret := _mock.Called(ctx, token, patID, description) diff --git a/auth/mocks/token_client.go b/auth/mocks/token_client.go index 025065092f..756512216a 100644 --- a/auth/mocks/token_client.go +++ b/auth/mocks/token_client.go @@ -208,3 +208,86 @@ func (_c *TokenServiceClient_Refresh_Call) RunAndReturn(run func(ctx context.Con _c.Call.Return(run) return _c } + +// Revoke provides a mock function for the type TokenServiceClient +func (_mock *TokenServiceClient) Revoke(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 *v1.RevokeRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*v1.RevokeRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *v1.RevokeRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.RevokeRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokenServiceClient_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type TokenServiceClient_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.RevokeReq +// - opts ...grpc.CallOption +func (_e *TokenServiceClient_Expecter) Revoke(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_Revoke_Call { + return &TokenServiceClient_Revoke_Call{Call: _e.mock.On("Revoke", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *TokenServiceClient_Revoke_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *TokenServiceClient_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.RevokeReq + if args[1] != nil { + arg1 = args[1].(*v1.RevokeReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) Return(revokeRes *v1.RevokeRes, err error) *TokenServiceClient_Revoke_Call { + _c.Call.Return(revokeRes, err) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error)) *TokenServiceClient_Revoke_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokenizer.go b/auth/mocks/tokenizer.go new file mode 100644 index 0000000000..11ea445a0b --- /dev/null +++ b/auth/mocks/tokenizer.go @@ -0,0 +1,226 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/auth" + mock "github.com/stretchr/testify/mock" +) + +// NewTokenizer creates a new instance of Tokenizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenizer(t interface { + mock.TestingT + Cleanup(func()) +}) *Tokenizer { + mock := &Tokenizer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Tokenizer is an autogenerated mock type for the Tokenizer type +type Tokenizer struct { + mock.Mock +} + +type Tokenizer_Expecter struct { + mock *mock.Mock +} + +func (_m *Tokenizer) EXPECT() *Tokenizer_Expecter { + return &Tokenizer_Expecter{mock: &_m.Mock} +} + +// Issue provides a mock function for the type Tokenizer +func (_mock *Tokenizer) Issue(key auth.Key) (string, error) { + ret := _mock.Called(key) + + if len(ret) == 0 { + panic("no return value specified for Issue") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(auth.Key) (string, error)); ok { + return returnFunc(key) + } + if returnFunc, ok := ret.Get(0).(func(auth.Key) string); ok { + r0 = returnFunc(key) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(auth.Key) error); ok { + r1 = returnFunc(key) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Tokenizer_Issue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Issue' +type Tokenizer_Issue_Call struct { + *mock.Call +} + +// Issue is a helper method to define mock.On call +// - key auth.Key +func (_e *Tokenizer_Expecter) Issue(key interface{}) *Tokenizer_Issue_Call { + return &Tokenizer_Issue_Call{Call: _e.mock.On("Issue", key)} +} + +func (_c *Tokenizer_Issue_Call) Run(run func(key auth.Key)) *Tokenizer_Issue_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 auth.Key + if args[0] != nil { + arg0 = args[0].(auth.Key) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Tokenizer_Issue_Call) Return(token string, err error) *Tokenizer_Issue_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Tokenizer_Issue_Call) RunAndReturn(run func(key auth.Key) (string, error)) *Tokenizer_Issue_Call { + _c.Call.Return(run) + return _c +} + +// Parse provides a mock function for the type Tokenizer +func (_mock *Tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for Parse") + } + + var r0 auth.Key + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (auth.Key, error)); ok { + return returnFunc(ctx, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) auth.Key); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Get(0).(auth.Key) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, token) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Tokenizer_Parse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Parse' +type Tokenizer_Parse_Call struct { + *mock.Call +} + +// Parse is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *Tokenizer_Expecter) Parse(ctx interface{}, token interface{}) *Tokenizer_Parse_Call { + return &Tokenizer_Parse_Call{Call: _e.mock.On("Parse", ctx, token)} +} + +func (_c *Tokenizer_Parse_Call) Run(run func(ctx context.Context, token string)) *Tokenizer_Parse_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Tokenizer_Parse_Call) Return(key auth.Key, err error) *Tokenizer_Parse_Call { + _c.Call.Return(key, err) + return _c +} + +func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token string) (auth.Key, error)) *Tokenizer_Parse_Call { + _c.Call.Return(run) + return _c +} + +// Revoke provides a mock function for the type Tokenizer +func (_mock *Tokenizer) Revoke(ctx context.Context, token string) error { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Tokenizer_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type Tokenizer_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *Tokenizer_Expecter) Revoke(ctx interface{}, token interface{}) *Tokenizer_Revoke_Call { + return &Tokenizer_Revoke_Call{Call: _e.mock.On("Revoke", ctx, token)} +} + +func (_c *Tokenizer_Revoke_Call) Run(run func(ctx context.Context, token string)) *Tokenizer_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Tokenizer_Revoke_Call) Return(err error) *Tokenizer_Revoke_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Tokenizer_Revoke_Call) RunAndReturn(run func(ctx context.Context, token string) error) *Tokenizer_Revoke_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokens_cache.go b/auth/mocks/tokens_cache.go new file mode 100644 index 0000000000..eaaa9a114e --- /dev/null +++ b/auth/mocks/tokens_cache.go @@ -0,0 +1,225 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + mock "github.com/stretchr/testify/mock" +) + +// NewTokensCache creates a new instance of TokensCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokensCache(t interface { + mock.TestingT + Cleanup(func()) +}) *TokensCache { + mock := &TokensCache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// TokensCache is an autogenerated mock type for the TokensCache type +type TokensCache struct { + mock.Mock +} + +type TokensCache_Expecter struct { + mock *mock.Mock +} + +func (_m *TokensCache) EXPECT() *TokensCache_Expecter { + return &TokensCache_Expecter{mock: &_m.Mock} +} + +// Contains provides a mock function for the type TokensCache +func (_mock *TokensCache) Contains(ctx context.Context, key string, value string) bool { + ret := _mock.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = returnFunc(ctx, key, value) + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// TokensCache_Contains_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Contains' +type TokensCache_Contains_Call struct { + *mock.Call +} + +// Contains is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - value string +func (_e *TokensCache_Expecter) Contains(ctx interface{}, key interface{}, value interface{}) *TokensCache_Contains_Call { + return &TokensCache_Contains_Call{Call: _e.mock.On("Contains", ctx, key, value)} +} + +func (_c *TokensCache_Contains_Call) Run(run func(ctx context.Context, key string, value string)) *TokensCache_Contains_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *TokensCache_Contains_Call) Return(b bool) *TokensCache_Contains_Call { + _c.Call.Return(b) + return _c +} + +func (_c *TokensCache_Contains_Call) RunAndReturn(run func(ctx context.Context, key string, value string) bool) *TokensCache_Contains_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function for the type TokensCache +func (_mock *TokensCache) Remove(ctx context.Context, key string) error { + ret := _mock.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, key) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensCache_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type TokensCache_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *TokensCache_Expecter) Remove(ctx interface{}, key interface{}) *TokensCache_Remove_Call { + return &TokensCache_Remove_Call{Call: _e.mock.On("Remove", ctx, key)} +} + +func (_c *TokensCache_Remove_Call) Run(run func(ctx context.Context, key string)) *TokensCache_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensCache_Remove_Call) Return(err error) *TokensCache_Remove_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensCache_Remove_Call) RunAndReturn(run func(ctx context.Context, key string) error) *TokensCache_Remove_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type TokensCache +func (_mock *TokensCache) Save(ctx context.Context, key string, value string) error { + ret := _mock.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, key, value) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensCache_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type TokensCache_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - value string +func (_e *TokensCache_Expecter) Save(ctx interface{}, key interface{}, value interface{}) *TokensCache_Save_Call { + return &TokensCache_Save_Call{Call: _e.mock.On("Save", ctx, key, value)} +} + +func (_c *TokensCache_Save_Call) Run(run func(ctx context.Context, key string, value string)) *TokensCache_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *TokensCache_Save_Call) Return(err error) *TokensCache_Save_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensCache_Save_Call) RunAndReturn(run func(ctx context.Context, key string, value string) error) *TokensCache_Save_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokens_repository.go b/auth/mocks/tokens_repository.go new file mode 100644 index 0000000000..b1a57ca80e --- /dev/null +++ b/auth/mocks/tokens_repository.go @@ -0,0 +1,156 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + mock "github.com/stretchr/testify/mock" +) + +// NewTokensRepository creates a new instance of TokensRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokensRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *TokensRepository { + mock := &TokensRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// TokensRepository is an autogenerated mock type for the TokensRepository type +type TokensRepository struct { + mock.Mock +} + +type TokensRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *TokensRepository) EXPECT() *TokensRepository_Expecter { + return &TokensRepository_Expecter{mock: &_m.Mock} +} + +// Contains provides a mock function for the type TokensRepository +func (_mock *TokensRepository) Contains(ctx context.Context, id string) bool { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// TokensRepository_Contains_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Contains' +type TokensRepository_Contains_Call struct { + *mock.Call +} + +// Contains is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *TokensRepository_Expecter) Contains(ctx interface{}, id interface{}) *TokensRepository_Contains_Call { + return &TokensRepository_Contains_Call{Call: _e.mock.On("Contains", ctx, id)} +} + +func (_c *TokensRepository_Contains_Call) Run(run func(ctx context.Context, id string)) *TokensRepository_Contains_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensRepository_Contains_Call) Return(ok bool) *TokensRepository_Contains_Call { + _c.Call.Return(ok) + return _c +} + +func (_c *TokensRepository_Contains_Call) RunAndReturn(run func(ctx context.Context, id string) bool) *TokensRepository_Contains_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type TokensRepository +func (_mock *TokensRepository) Save(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type TokensRepository_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *TokensRepository_Expecter) Save(ctx interface{}, id interface{}) *TokensRepository_Save_Call { + return &TokensRepository_Save_Call{Call: _e.mock.On("Save", ctx, id)} +} + +func (_c *TokensRepository_Save_Call) Run(run func(ctx context.Context, id string)) *TokensRepository_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensRepository_Save_Call) Return(err error) *TokensRepository_Save_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensRepository_Save_Call) RunAndReturn(run func(ctx context.Context, id string) error) *TokensRepository_Save_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/postgres/init.go b/auth/postgres/init.go index dfdec3a748..0351d91695 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -125,6 +125,17 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE pats ALTER COLUMN last_used_at TYPE TIMESTAMP;`, }, }, + { + Id: "auth_7", + Up: []string{ + `CREATE TABLE IF NOT EXISTS revoked_tokens ( + id VARCHAR(36) PRIMARY KEY + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS revoked_tokens`, + }, + }, }, } } diff --git a/auth/postgres/token.go b/auth/postgres/token.go new file mode 100644 index 0000000000..83714fe9ce --- /dev/null +++ b/auth/postgres/token.go @@ -0,0 +1,65 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" +) + +var _ auth.TokensRepository = (*tokenRepo)(nil) + +type tokenRepo struct { + db postgres.Database +} + +// NewTokensRepository instantiates a PostgreSQL implementation of tokens repository. +func NewTokensRepository(db postgres.Database) auth.TokensRepository { + return &tokenRepo{ + db: db, + } +} + +func (repo *tokenRepo) Save(ctx context.Context, id string) error { + q := `INSERT INTO revoked_tokens (id) VALUES ($1);` + + result, err := repo.db.ExecContext(ctx, q, id) + if err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + rows, err := result.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + if rows == 0 { + return repoerr.ErrCreateEntity + } + + return nil +} + +func (repo *tokenRepo) Contains(ctx context.Context, id string) bool { + q := `SELECT * FROM revoked_tokens WHERE id = $1;` + + rows, err := repo.db.QueryContext(ctx, q, id) + if err != nil { + return false + } + defer rows.Close() + + if rows.Next() { + id := "" + if err = rows.Scan(&id); err != nil { + return false + } + + return true + } + + return false +} diff --git a/auth/service.go b/auth/service.go index e153fb8202..e641704f86 100644 --- a/auth/service.go +++ b/auth/service.go @@ -19,7 +19,6 @@ import ( const ( recoveryDuration = 5 * time.Minute - defLimit = 100 randStr = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&&*|+-=" patPrefix = "pat" patSecretSeparator = "_" @@ -67,6 +66,9 @@ type Authn interface { // Issue issues a new Key, returning its token value alongside. Issue(ctx context.Context, token string, key Key) (Token, error) + // RevokeToken revokes the token. + RevokeToken(ctx context.Context, token string) error + // Revoke removes the Key with the provided id that is // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error @@ -140,8 +142,12 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } } +func (svc service) RevokeToken(ctx context.Context, token string) error { + return svc.tokenizer.Revoke(ctx, token) +} + func (svc service) Revoke(ctx context.Context, token, id string) error { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return errors.Wrap(errRevoke, err) } @@ -152,7 +158,7 @@ func (svc service) Revoke(ctx context.Context, token, id string) error { } func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return Key{}, errors.Wrap(errRetrieve, err) } @@ -165,7 +171,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro } func (svc service) Identify(ctx context.Context, token string) (Key, error) { - key, err := svc.tokenizer.Parse(token) + key, err := svc.tokenizer.Parse(ctx, token) if errors.Contains(err, ErrExpiry) { err = svc.keys.Remove(ctx, key.Issuer, key.ID) return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err)) @@ -308,7 +314,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) { } func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) { - k, err := svc.tokenizer.Parse(token) + k, err := svc.tokenizer.Parse(ctx, token) if err != nil { return Token{}, errors.Wrap(errRetrieve, err) } @@ -385,7 +391,7 @@ func (svc service) getUserRole(ctx context.Context, userID string) (role Role) { } func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) { - id, sub, err := svc.authenticate(token) + id, sub, err := svc.authenticate(ctx, token) if err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -416,8 +422,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e return Token{AccessToken: tkn}, nil } -func (svc service) authenticate(token string) (string, string, error) { - key, err := svc.tokenizer.Parse(token) +func (svc service) authenticate(ctx context.Context, token string) (string, string, error) { + key, err := svc.tokenizer.Parse(ctx, token) if err != nil { return "", "", errors.Wrap(svcerr.ErrAuthentication, err) } diff --git a/auth/service_test.go b/auth/service_test.go index e1bf10b93a..c83dc7cdb1 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -25,12 +25,6 @@ import ( const ( secret = "secret" - email = "test@example.com" - id = "testID" - groupName = "smqx" - description = "Description" - memberRelation = "member" - authoritiesObj = "authorities" loginDuration = 30 * time.Minute refreshDuration = 24 * time.Hour invalidDuration = 7 * 24 * time.Hour @@ -53,6 +47,8 @@ var ( patsrepo *mocks.PATSRepository cache *mocks.Cache hasher *mocks.Hasher + trepo *mocks.TokensRepository + tcache *mocks.TokensCache ) func newService() (auth.Service, string) { @@ -63,8 +59,10 @@ func newService() (auth.Service, string) { patsrepo = new(mocks.PATSRepository) hasher = new(mocks.Hasher) idProvider := uuid.NewMock() + trepo = new(mocks.TokensRepository) + tcache = new(mocks.TokensCache) - t := jwt.New([]byte(secret)) + t := jwt.New([]byte(secret), trepo, tcache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -80,7 +78,7 @@ func newService() (auth.Service, string) { func TestIssue(t *testing.T) { svc, accessToken := newService() - n := jwt.New([]byte(secret)) + n := jwt.New([]byte(secret), trepo, tcache) apikey := auth.Key{ IssuedAt: time.Now(), @@ -313,9 +311,13 @@ func TestIssue(t *testing.T) { Object: policies.SuperMQObject, ObjectType: policies.PlatformType, }).Return(tc.roleCheckErr) + cacheCall := tcache.On("Contains", mock.Anything, "", "").Return(false) + repoCall := trepo.On("Contains", mock.Anything, "").Return(false) _, err := svc.Issue(context.Background(), tc.token, tc.key) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) policyCall.Unset() + cacheCall.Unset() + repoCall.Unset() } } @@ -469,7 +471,7 @@ func TestIdentify(t *testing.T) { repoCall.Unset() repoCall1.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, tcache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -538,11 +540,15 @@ func TestIdentify(t *testing.T) { for _, tc := range cases { repoCall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{}, tc.err) repoCall1 := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err) + cacheCall := tcache.On("Contains", mock.Anything, "", "").Return(false) + repoCall2 := trepo.On("Contains", mock.Anything, "").Return(false) idt, err := svc.Identify(context.Background(), tc.key) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.subject, idt.Subject, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.subject, idt)) repoCall.Unset() repoCall1.Unset() + cacheCall.Unset() + repoCall2.Unset() } } diff --git a/auth/tokenizer.go b/auth/tokenizer.go index 1aaed7df4f..0d94d00d00 100644 --- a/auth/tokenizer.go +++ b/auth/tokenizer.go @@ -3,11 +3,38 @@ package auth +import "context" + // Tokenizer specifies API for encoding and decoding between string and Key. type Tokenizer interface { // Issue converts API Key to its string representation. Issue(key Key) (token string, err error) // Parse extracts API Key data from string token. - Parse(token string) (key Key, err error) + Parse(ctx context.Context, token string) (key Key, err error) + + // Revoke revokes the token. + Revoke(ctx context.Context, token string) error +} + +// TokensCache represents a cache repository. It exposes functionalities +// through `auth` to perform caching. +type TokensCache interface { + // Save saves the key-value pair in the cache. + Save(ctx context.Context, key, value string) error + + // Contains checks if the key-value pair exists in the cache. + Contains(ctx context.Context, key, value string) bool + + // Remove removes the key from the cache. + Remove(ctx context.Context, key string) error +} + +// TokensRepository specifies token persistence API. +type TokensRepository interface { + // Save persists the token. + Save(ctx context.Context, id string) (err error) + + // Contains checks if token with provided ID exists. + Contains(ctx context.Context, id string) (ok bool) } diff --git a/cmd/auth/main.go b/cmd/auth/main.go index adffc74971..b5081f4428 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -226,18 +226,20 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch } func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration) (auth.Service, error) { - cache := cache.NewPatsCache(cacheClient, keyDuration) + patsCache := cache.NewPatsCache(cacheClient, keyDuration) + tokensCache := cache.NewTokensCache(cacheClient, keyDuration) database := pgclient.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) - patsRepo := apostgres.NewPatRepo(database, cache) + patsRepo := apostgres.NewPatRepo(database, patsCache) + tokensRepo := apostgres.NewTokensRepository(database) hasher := hasher.New() idProvider := uuid.New() pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger) pService := spicedb.NewPolicyService(spicedbClient, logger) - t := jwt.New([]byte(cfg.SecretKey)) + t := jwt.New([]byte(cfg.SecretKey), tokensRepo, tokensCache) svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = middleware.NewLogging(svc, logger) diff --git a/internal/proto/token/v1/token.proto b/internal/proto/token/v1/token.proto index 10c066511f..de8e61129e 100644 --- a/internal/proto/token/v1/token.proto +++ b/internal/proto/token/v1/token.proto @@ -9,6 +9,7 @@ option go_package = "github.com/absmach/supermq/api/grpc/token/v1"; service TokenService { rpc Issue(IssueReq) returns (Token) {} rpc Refresh(RefreshReq) returns (Token) {} + rpc Revoke(RevokeReq) returns (RevokeRes) {} } message IssueReq { @@ -23,6 +24,10 @@ message RefreshReq { bool verified = 2; } +message RevokeReq { + string token = 1; +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -31,3 +36,7 @@ message Token { optional string refresh_token = 2; string access_type = 3; } + +message RevokeRes{ + +} \ No newline at end of file diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 6e8228406f..8c001c75bb 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -60,6 +60,9 @@ packages: KeyRepository: PATS: PATSRepository: + TokensRepository: + TokensCache: + Tokenizer: Service: github.com/absmach/supermq/channels: interfaces: