From 65f43a5880a4eebfa14801e63b4265d966f5aadc Mon Sep 17 00:00:00 2001 From: le-xuan-quynh Date: Tue, 3 May 2022 20:37:24 +0700 Subject: [PATCH] add get access token API --- pkg/authorization/endpoints/endpoints.go | 23 +++++++++++++++ pkg/authorization/middleware/middleware.go | 15 +++++----- pkg/authorization/reqresponse.go | 12 ++++++++ pkg/authorization/service.go | 1 + pkg/authorization/transport/http.go | 25 ++++++++++++++++ pkg/authorization/users-service.go | 34 ++++++++++++++++++++++ 6 files changed, 103 insertions(+), 7 deletions(-) diff --git a/pkg/authorization/endpoints/endpoints.go b/pkg/authorization/endpoints/endpoints.go index 1237800..453f8e5 100644 --- a/pkg/authorization/endpoints/endpoints.go +++ b/pkg/authorization/endpoints/endpoints.go @@ -24,6 +24,7 @@ type Set struct { UpdatePasswordEndpoint endpoint.Endpoint GetForgetPasswordCodeEndpoint endpoint.Endpoint ResetPasswordEndpoint endpoint.Endpoint + GenerateAccessTokenEndpoint endpoint.Endpoint } func NewEndpointSet(svc authorization.Service, @@ -80,6 +81,11 @@ func NewEndpointSet(svc authorization.Service, resetPasswordEndpoint = middleware.RateLimitRequest(tb, logger)(resetPasswordEndpoint) resetPasswordEndpoint = middleware.ValidateParamRequest(validator, logger)(resetPasswordEndpoint) + generateAccessTokenEndpoint := MakeGenerateAccessTokenEndpoint(svc) + generateAccessTokenEndpoint = middleware.RateLimitRequest(tb, logger)(generateAccessTokenEndpoint) + generateAccessTokenEndpoint = middleware.ValidateParamRequest(validator, logger)(generateAccessTokenEndpoint) + generateAccessTokenEndpoint = middleware.ValidateRefreshToken(auth, r, logger)(generateAccessTokenEndpoint) + return Set{ HealthCheckEndpoint: healthCheckEndpoint, RegisterEndpoint: registerEndpoint, @@ -92,6 +98,7 @@ func NewEndpointSet(svc authorization.Service, UpdatePasswordEndpoint: updatePasswordEndpoint, GetForgetPasswordCodeEndpoint: getForgetPasswordCodeEndpoint, ResetPasswordEndpoint: resetPasswordEndpoint, + GenerateAccessTokenEndpoint: generateAccessTokenEndpoint, } } @@ -285,3 +292,19 @@ func MakeCreateNewPasswordWithCodeEndpoint(svc authorization.Service) endpoint.E return "successfully updated password.", nil } } + +// MakeGenerateAccessTokenEndpoint returns an endpoint that invokes GenerateAccessToken on the service. +func MakeGenerateAccessTokenEndpoint(svc authorization.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req, ok := request.(authorization.GenerateAccessTokenRequest) + if !ok { + cusErr := utils.NewErrorResponse(utils.BadRequest) + return nil, cusErr + } + token, err := svc.GenerateAccessToken(ctx, &req) + if err != nil { + return nil, err + } + return token, nil + } +} diff --git a/pkg/authorization/middleware/middleware.go b/pkg/authorization/middleware/middleware.go index d229352..887993a 100644 --- a/pkg/authorization/middleware/middleware.go +++ b/pkg/authorization/middleware/middleware.go @@ -20,22 +20,23 @@ type UserIDKey struct{} func ValidateRefreshToken(auth Authentication, r database.UserRepository, logger hclog.Logger) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - authErr := authorizedRefreshToken(ctx, auth, r, logger, request) + userID, authErr := authorizedRefreshToken(ctx, auth, r, logger, request) if authErr != nil { return nil, authErr } + ctx = context.WithValue(ctx, UserIDKey{}, userID) return next(ctx, request) } } } // authorizedRefreshToken validates the refresh token. -func authorizedRefreshToken(ctx context.Context, auth Authentication, r database.UserRepository, logger hclog.Logger, request interface{}) error { +func authorizedRefreshToken(ctx context.Context, auth Authentication, r database.UserRepository, logger hclog.Logger, request interface{}) (string, error) { token, err := extractValue(request, "refresh_token") if err != nil { logger.Error("extract value token failed", "err", err) cusErr := utils.NewErrorResponse(utils.BadRequest) - return cusErr + return "", cusErr } logger.Debug("token present in header", token) @@ -43,7 +44,7 @@ func authorizedRefreshToken(ctx context.Context, auth Authentication, r database if err != nil { logger.Error("token validation failed", "error", err) cusErr := utils.NewErrorResponse(utils.ValidationTokenFailure) - return cusErr + return "", cusErr } logger.Debug("refresh token validated") @@ -51,16 +52,16 @@ func authorizedRefreshToken(ctx context.Context, auth Authentication, r database if err != nil { logger.Error("You're not authorized. Please try again latter.", err) cusErr := utils.NewErrorResponse(utils.ValidationTokenFailure) - return cusErr + return "", cusErr } actualCustomKey := auth.GenerateCustomKey(user.ID, user.TokenHash) if customKey != actualCustomKey { logger.Debug("wrong token: authentication failed") cusErr := utils.NewErrorResponse(utils.Unauthorized) - return cusErr + return "", cusErr } - return nil + return userID, nil } func extractValue(request interface{}, key string) (string, error) { diff --git a/pkg/authorization/reqresponse.go b/pkg/authorization/reqresponse.go index bbf9187..f197fd5 100644 --- a/pkg/authorization/reqresponse.go +++ b/pkg/authorization/reqresponse.go @@ -120,3 +120,15 @@ type CreateNewPasswordWithCodeRequest struct { Email string `json:"email" validate:"required,email"` NewPassword string `json:"new_password" validate:"required"` } + +// GenerateAccessTokenRequest is used to generate access token +type GenerateAccessTokenRequest struct { + RefreshToken string `json:"refresh_token" validate:"required"` +} + +// GenerateAccessResponse is the response for generate access token +type GenerateAccessResponse struct { + RefreshToken string `json:"refresh_token,omitempty"` + AccessToken string `json:"access_token,omitempty"` + Username string `json:"username,omitempty"` +} diff --git a/pkg/authorization/service.go b/pkg/authorization/service.go index f7fbc71..449c36a 100644 --- a/pkg/authorization/service.go +++ b/pkg/authorization/service.go @@ -15,4 +15,5 @@ type Service interface { UpdatePassword(ctx context.Context, request *UpdatePasswordRequest) (string, error) GetForgetPasswordCode(ctx context.Context, email string) error ResetPassword(ctx context.Context, request *CreateNewPasswordWithCodeRequest) error + GenerateAccessToken(ctx context.Context, request *GenerateAccessTokenRequest) (interface{}, error) } diff --git a/pkg/authorization/transport/http.go b/pkg/authorization/transport/http.go index 484d05a..24d70f7 100644 --- a/pkg/authorization/transport/http.go +++ b/pkg/authorization/transport/http.go @@ -82,6 +82,13 @@ func NewHTTPHandler(ep endpoints.Set) http.Handler { encodeResponse, options..., )) + m.Handle("/generate-access-token", httptransport.NewServer( + ep.GenerateAccessTokenEndpoint, + decodeHTTPGenerateAccessTokenRequest, + encodeResponse, + options..., + )) + mux := http.NewServeMux() mux.Handle("/api/v1/", http.StripPrefix("/api/v1", m)) return mux @@ -298,6 +305,24 @@ func decodeHTTPResetPasswordRequest(_ context.Context, r *http.Request) (interfa } } +// decodeHTTPGenerateAccessTokenRequest decode request +func decodeHTTPGenerateAccessTokenRequest(_ context.Context, r *http.Request) (interface{}, error) { + if r.Method == "POST" { + var req authorization.GenerateAccessTokenRequest + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + return nil, utils.NewErrorResponse(utils.BadRequest) + } + if req.RefreshToken == "" { + return nil, utils.NewErrorResponse(utils.RefreshTokenRequired) + } + return req, nil + } else { + cusErr := utils.NewErrorResponse(utils.MethodNotAllowed) + return nil, cusErr + } +} + func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { w.Header().Set("Content-Type", "application/json; charset=utf-8") diff --git a/pkg/authorization/users-service.go b/pkg/authorization/users-service.go index 9f2a94d..deaaf86 100644 --- a/pkg/authorization/users-service.go +++ b/pkg/authorization/users-service.go @@ -775,3 +775,37 @@ func (s *userService) ResetPassword(ctx context.Context, request *CreateNewPassw s.logger.Info("Password changed", "userID", user.ID) return nil } + +// GenerateAccessToken generate access token +func (s *userService) GenerateAccessToken(ctx context.Context, request *GenerateAccessTokenRequest) (interface{}, error) { + userID, ok := ctx.Value(middleware.UserIDKey{}).(string) + if !ok { + s.logger.Error("Error getting userID from context") + cusErr := utils.NewErrorResponse(utils.InternalServerError) + return cusErr.Error(), cusErr + } + user, err := s.repo.GetUserByID(ctx, userID) + if err != nil { + s.logger.Error("Cannot get user", "error", err) + cusErr := utils.NewErrorResponse(utils.InternalServerError) + return cusErr.Error(), cusErr + } + // Check if user is banned + if user.Banned { + s.logger.Error("User is banned", "error", err) + cusErr := utils.NewErrorResponse(utils.Forbidden) + return cusErr.Error(), cusErr + } + accessToken, err := s.auth.GenerateAccessToken(user) + if err != nil { + s.logger.Error("unable to generate access token", "error", err) + cusErr := utils.NewErrorResponse(utils.InternalServerError) + return cusErr.Error(), cusErr + } + + s.logger.Debug("Successfully generated new access token") + return GenerateAccessResponse{ + AccessToken: accessToken, + Username: user.Username, + }, nil +}