From 3033b491a081001ef5cc0a2a7dd912a9f08b11b2 Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:21:31 +1000 Subject: [PATCH 1/9] We can provide the full request. --- server/handler.go | 126 +++--- server/server.go | 1059 +++++++++++++++++++++++---------------------- 2 files changed, 593 insertions(+), 592 deletions(-) diff --git a/server/handler.go b/server/handler.go index e0d5d32..df66137 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,63 +1,63 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(clientID, scope string) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - if clientID == "" { - return "", "", errors.ErrInvalidClient - } - clientSecret := r.Form.Get("client_secret") - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +} diff --git a/server/server.go b/server/server.go index ca1cd94..6887431 100755 --- a/server/server.go +++ b/server/server.go @@ -1,529 +1,530 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(req.ClientID, req.Scope) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr.ClientID, tgr.Scope) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(scope, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +// NewDefaultServer create a default authorization server +func NewDefaultServer(manager oauth2.Manager) *Server { + return NewServer(NewConfig(), manager) +} + +// NewServer create authorization server +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + srv := &Server{ + Config: cfg, + Manager: manager, + } + + // default handler + srv.ClientInfoHandler = ClientBasicHandler + + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { + return "", errors.ErrAccessDenied + } + + srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + return "", errors.ErrAccessDenied + } + return srv +} + +// Server Provide authorization server +type Server struct { + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingScopeHandler RefreshingScopeHandler + ResponseErrorHandler ResponseErrorHandler + InternalErrorHandler InternalErrorHandler + ExtensionFieldsHandler ExtensionFieldsHandler + AccessTokenExpHandler AccessTokenExpHandler + AuthorizeScopeHandler AuthorizeScopeHandler +} + +func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if req == nil { + return err + } + data, _, _ := s.GetErrorData(err) + return s.redirect(w, req, data) +} + +func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { + uri, err := s.GetRedirectURI(req, data) + if err != nil { + return err + } + + w.Header().Set("Location", uri) + w.WriteHeader(302) + return nil +} + +func (s *Server) tokenError(w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(err) + return s.token(w, data, header, statusCode) +} + +func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + for key := range header { + w.Header().Set(key, header.Get(key)) + } + + status := http.StatusOK + if len(statusCode) > 0 && statusCode[0] > 0 { + status = statusCode[0] + } + + w.WriteHeader(status) + return json.NewEncoder(w).Encode(data) +} + +// GetRedirectURI get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { + u, err := url.Parse(req.RedirectURI) + if err != nil { + return "", err + } + + q := u.Query() + if req.State != "" { + q.Set("state", req.State) + } + + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + + switch req.ResponseType { + case oauth2.Code: + u.RawQuery = q.Encode() + case oauth2.Token: + u.RawQuery = "" + fragment, err := url.QueryUnescape(q.Encode()) + if err != nil { + return "", err + } + u.Fragment = fragment + } + + return u.String(), nil +} + +// CheckResponseType check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { + if art == rt { + return true + } + } + return false +} + +// ValidationAuthorizeRequest the authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if !(r.Method == "GET" || r.Method == "POST") || + clientID == "" { + return nil, errors.ErrInvalidRequest + } + + resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { + return nil, errors.ErrUnsupportedResponseType + } else if allowed := s.CheckResponseType(resType); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + req := &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + } + return req, nil +} + +// GetAuthorizeToken get authorization token(code) +func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { + // check the client allows the grant type + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + + allowed, err := fn(req.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + // check the client allows the authorized scope + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) +} + +// GetAuthorizeData get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { + if rt == oauth2.Code { + return map[string]interface{}{ + "code": ti.GetCode(), + } + } + return s.GetTokenData(ti) +} + +// HandleAuthorizeRequest the authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + req, err := s.ValidationAuthorizeRequest(r) + if err != nil { + return s.redirectError(w, req, err) + } + + // user authorization + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + return s.redirectError(w, req, err) + } else if userID == "" { + return nil + } + req.UserID = userID + + // specify the scope of authorization + if fn := s.AuthorizeScopeHandler; fn != nil { + scope, err := fn(w, r) + if err != nil { + return err + } else if scope != "" { + req.Scope = scope + } + } + + // specify the expiration time of access token + if fn := s.AccessTokenExpHandler; fn != nil { + exp, err := fn(w, r) + if err != nil { + return err + } + req.AccessTokenExp = exp + } + + ti, err := s.GetAuthorizeToken(ctx, req) + if err != nil { + return s.redirectError(w, req, err) + } + + // If the redirect URI is empty, the default domain provided by the client is used. + if req.RedirectURI == "" { + client, err := s.Manager.GetClient(ctx, req.ClientID) + if err != nil { + return err + } + req.RedirectURI = client.GetDomain() + } + + return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) +} + +// ValidationTokenRequest the token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { + if v := r.Method; !(v == "POST" || + (s.Config.AllowGetAccessRequest && v == "GET")) { + return "", nil, errors.ErrInvalidRequest + } + + gt := oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { + return "", nil, errors.ErrUnsupportedGrantType + } + + clientID, clientSecret, err := s.ClientInfoHandler(r) + if err != nil { + return "", nil, err + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Request: r, + } + + switch gt { + case oauth2.AuthorizationCode: + tgr.RedirectURI = r.FormValue("redirect_uri") + tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || + tgr.Code == "" { + return "", nil, errors.ErrInvalidRequest + } + case oauth2.PasswordCredentials: + tgr.Scope = r.FormValue("scope") + username, password := r.FormValue("username"), r.FormValue("password") + if username == "" || password == "" { + return "", nil, errors.ErrInvalidRequest + } + + userID, err := s.PasswordAuthorizationHandler(username, password) + if err != nil { + return "", nil, err + } else if userID == "" { + return "", nil, errors.ErrInvalidGrant + } + tgr.UserID = userID + case oauth2.ClientCredentials: + tgr.Scope = r.FormValue("scope") + case oauth2.Refreshing: + tgr.Refresh = r.FormValue("refresh_token") + tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { + return "", nil, errors.ErrInvalidRequest + } + } + return gt, tgr, nil +} + +// CheckGrantType check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { + if agt == gt { + return true + } + } + return false +} + +// GetAccessToken access token +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + if allowed := s.CheckGrantType(gt); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + switch gt { + case oauth2.AuthorizationCode: + ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) + if err != nil { + switch err { + case errors.ErrInvalidAuthorizeCode: + return nil, errors.ErrInvalidGrant + case errors.ErrInvalidClient: + return nil, errors.ErrInvalidClient + default: + return nil, err + } + } + return ti, nil + case oauth2.PasswordCredentials, oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + return s.Manager.GenerateAccessToken(ctx, gt, tgr) + case oauth2.Refreshing: + // check scope + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + + allowed, err := scopeFn(tgr, rti.GetScope()) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + return ti, nil + } + + return nil, errors.ErrUnsupportedGrantType +} + +// GetTokenData token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), + } + + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } + return data +} + +// HandleTokenRequest token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + gt, tgr, err := s.ValidationTokenRequest(r) + if err != nil { + return s.tokenError(w, err) + } + + ti, err := s.GetAccessToken(ctx, gt, tgr) + if err != nil { + return s.tokenError(w, err) + } + + return s.token(w, s.GetTokenData(ti), nil) +} + +// GetErrorData get error response data +func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { + var re errors.Response + if v, ok := errors.Descriptions[err]; ok { + re.Error = err + re.Description = v + re.StatusCode = errors.StatusCodes[err] + } else { + if fn := s.InternalErrorHandler; fn != nil { + if v := fn(err); v != nil { + re = *v + } + } + + if re.Error == nil { + re.Error = errors.ErrServerError + re.Description = errors.Descriptions[errors.ErrServerError] + re.StatusCode = errors.StatusCodes[errors.ErrServerError] + } + } + + if fn := s.ResponseErrorHandler; fn != nil { + fn(&re) + } + + data := make(map[string]interface{}) + if err := re.Error; err != nil { + data["error"] = err.Error() + } + + if v := re.ErrorCode; v != 0 { + data["error_code"] = v + } + + if v := re.Description; v != "" { + data["error_description"] = v + } + + if v := re.URI; v != "" { + data["error_uri"] = v + } + + statusCode := http.StatusInternalServerError + if v := re.StatusCode; v > 0 { + statusCode = v + } + + return data, statusCode, re.Header +} + +// BearerAuth parse bearer token +func (s *Server) BearerAuth(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + prefix := "Bearer " + token := "" + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +// ValidationBearerToken validation the bearer tokens +// https://tools.ietf.org/html/rfc6750 +func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + ctx := r.Context() + + accessToken, ok := s.BearerAuth(r) + if !ok { + return nil, errors.ErrInvalidAccessToken + } + + return s.Manager.LoadAccessToken(ctx, accessToken) +} From 2f083c345942f316419efe87fae5e9e92d754a0b Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:28:09 +1000 Subject: [PATCH 2/9] git ignore added. --- .editorconfig | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..6c376a8 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,2 @@ +[*.go] +end_of_line = lf \ No newline at end of file From 3ef29f48b6ff83d5a98e3a51452ca94c0540ffa9 Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:33:17 +1000 Subject: [PATCH 3/9] Updated --- .editorconfig | 2 +- server/handler.go | 64 +----- server/server.go | 531 +--------------------------------------------- 3 files changed, 3 insertions(+), 594 deletions(-) diff --git a/.editorconfig b/.editorconfig index 6c376a8..d8549da 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = lf \ No newline at end of file +end_of_line = cr \ No newline at end of file diff --git a/server/handler.go b/server/handler.go index df66137..e55e448 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,63 +1 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - if clientID == "" { - return "", "", errors.ErrInvalidClient - } - clientSecret := r.Form.Get("client_secret") - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server import ( "net/http" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) type ( // ClientInfoHandler get client info from request ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) // ClientAuthorizedHandler check the client allows to use this authorization grant type ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) // ClientScopeHandler check the client allows to use scope ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) // UserAuthorizationHandler get user id from request authorization UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password PasswordAuthorizationHandler func(username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) // ResponseErrorHandler response error handing ResponseErrorHandler func(re *errors.Response) // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) // AccessTokenExpHandler set expiration date for the access token AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) ) // ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (string, string, error) { clientID := r.Form.Get("client_id") if clientID == "" { return "", "", errors.ErrInvalidClient } clientSecret := r.Form.Get("client_secret") return clientID, clientSecret, nil } // ClientBasicHandler get client data from basic authorization func ClientBasicHandler(r *http.Request) (string, string, error) { username, password, ok := r.BasicAuth() if !ok { return "", "", errors.ErrInvalidClient } return username, password, nil } \ No newline at end of file diff --git a/server/server.go b/server/server.go index 6887431..cf8ec7b 100755 --- a/server/server.go +++ b/server/server.go @@ -1,530 +1 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(tgr, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} +package server import ( "context" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) // NewDefaultServer create a default authorization server func NewDefaultServer(manager oauth2.Manager) *Server { return NewServer(NewConfig(), manager) } // NewServer create authorization server func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ Config: cfg, Manager: manager, } // default handler srv.ClientInfoHandler = ClientBasicHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv } // Server Provide authorization server type Server struct { Config *Config Manager oauth2.Manager ClientInfoHandler ClientInfoHandler ClientAuthorizedHandler ClientAuthorizedHandler ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { uri, err := s.GetRedirectURI(req, data) if err != nil { return err } w.Header().Set("Location", uri) w.WriteHeader(302) return nil } func (s *Server) tokenError(w http.ResponseWriter, err error) error { data, statusCode, header := s.GetErrorData(err) return s.token(w, data, header, statusCode) } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") for key := range header { w.Header().Set(key, header.Get(key)) } status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } w.WriteHeader(status) return json.NewEncoder(w).Encode(data) } // GetRedirectURI get redirect uri func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { u, err := url.Parse(req.RedirectURI) if err != nil { return "", err } q := u.Query() if req.State != "" { q.Set("state", req.State) } for k, v := range data { q.Set(k, fmt.Sprint(v)) } switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() case oauth2.Token: u.RawQuery = "" fragment, err := url.QueryUnescape(q.Encode()) if err != nil { return "", err } u.Fragment = fragment } return u.String(), nil } // CheckResponseType check allows response type func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } } return false } // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { return nil, errors.ErrInvalidRequest } resType := oauth2.ResponseType(r.FormValue("response_type")) if resType.String() == "" { return nil, errors.ErrUnsupportedResponseType } else if allowed := s.CheckResponseType(resType); !allowed { return nil, errors.ErrUnauthorizedClient } req := &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: resType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, } return req, nil } // GetAuthorizeToken get authorization token(code) func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } allowed, err := fn(req.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } tgr := &oauth2.TokenGenerateRequest{ ClientID: req.ClientID, UserID: req.UserID, RedirectURI: req.RedirectURI, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, Request: req.Request, } // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } // GetAuthorizeData get authorization response data func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { if rt == oauth2.Code { return map[string]interface{}{ "code": ti.GetCode(), } } return s.GetTokenData(ti) } // HandleAuthorizeRequest the authorization request handling func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() req, err := s.ValidationAuthorizeRequest(r) if err != nil { return s.redirectError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { return s.redirectError(w, req, err) } else if userID == "" { return nil } req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { return err } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { return err } req.AccessTokenExp = exp } ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { return s.redirectError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. if req.RedirectURI == "" { client, err := s.Manager.GetClient(ctx, req.ClientID) if err != nil { return err } req.RedirectURI = client.GetDomain() } return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) } // ValidationTokenRequest the token request validation func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { if v := r.Method; !(v == "POST" || (s.Config.AllowGetAccessRequest && v == "GET")) { return "", nil, errors.ErrInvalidRequest } gt := oauth2.GrantType(r.FormValue("grant_type")) if gt.String() == "" { return "", nil, errors.ErrUnsupportedGrantType } clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, Request: r, } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") if tgr.RedirectURI == "" || tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") if username == "" || password == "" { return "", nil, errors.ErrInvalidRequest } userID, err := s.PasswordAuthorizationHandler(username, password) if err != nil { return "", nil, err } else if userID == "" { return "", nil, errors.ErrInvalidGrant } tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: tgr.Refresh = r.FormValue("refresh_token") tgr.Scope = r.FormValue("scope") if tgr.Refresh == "" { return "", nil, errors.ErrInvalidRequest } } return gt, tgr, nil } // CheckGrantType check allows grant type func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } } return false } // GetAccessToken access token func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } if fn := s.ClientAuthorizedHandler; fn != nil { allowed, err := fn(tgr.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } switch gt { case oauth2.AuthorizationCode: ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { case errors.ErrInvalidAuthorizeCode: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient default: return nil, err } } return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } return ti, nil } return nil, errors.ErrUnsupportedGrantType } // GetTokenData token data func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { data := map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope } if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { if _, ok := data[k]; ok { continue } data[k] = v } } return data } // HandleTokenRequest token request handling func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { return s.tokenError(w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { return s.tokenError(w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err re.Description = v re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { if v := fn(err); v != nil { re = *v } } if re.Error == nil { re.Error = errors.ErrServerError re.Description = errors.Descriptions[errors.ErrServerError] re.StatusCode = errors.StatusCodes[errors.ErrServerError] } } if fn := s.ResponseErrorHandler; fn != nil { fn(&re) } data := make(map[string]interface{}) if err := re.Error; err != nil { data["error"] = err.Error() } if v := re.ErrorCode; v != 0 { data["error_code"] = v } if v := re.Description; v != "" { data["error_description"] = v } if v := re.URI; v != "" { data["error_uri"] = v } statusCode := http.StatusInternalServerError if v := re.StatusCode; v > 0 { statusCode = v } return data, statusCode, re.Header } // BearerAuth parse bearer token func (s *Server) BearerAuth(r *http.Request) (string, bool) { auth := r.Header.Get("Authorization") prefix := "Bearer " token := "" if auth != "" && strings.HasPrefix(auth, prefix) { token = auth[len(prefix):] } else { token = r.FormValue("access_token") } return token, token != "" } // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() accessToken, ok := s.BearerAuth(r) if !ok { return nil, errors.ErrInvalidAccessToken } return s.Manager.LoadAccessToken(ctx, accessToken) } \ No newline at end of file From 36218409d9834c81115521445221bf4845eceedd Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:33:55 +1000 Subject: [PATCH 4/9] Updated --- .editorconfig | 2 +- server/handler.go | 65 +++++- server/server.go | 532 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 596 insertions(+), 3 deletions(-) diff --git a/.editorconfig b/.editorconfig index d8549da..6c376a8 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = cr \ No newline at end of file +end_of_line = lf \ No newline at end of file diff --git a/server/handler.go b/server/handler.go index e55e448..10e0e96 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1 +1,64 @@ -package server import ( "net/http" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) type ( // ClientInfoHandler get client info from request ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) // ClientAuthorizedHandler check the client allows to use this authorization grant type ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) // ClientScopeHandler check the client allows to use scope ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) // UserAuthorizationHandler get user id from request authorization UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password PasswordAuthorizationHandler func(username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) // ResponseErrorHandler response error handing ResponseErrorHandler func(re *errors.Response) // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) // AccessTokenExpHandler set expiration date for the access token AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) ) // ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (string, string, error) { clientID := r.Form.Get("client_id") if clientID == "" { return "", "", errors.ErrInvalidClient } clientSecret := r.Form.Get("client_secret") return clientID, clientSecret, nil } // ClientBasicHandler get client data from basic authorization func ClientBasicHandler(r *http.Request) (string, string, error) { username, password, ok := r.BasicAuth() if !ok { return "", "", errors.ErrInvalidClient } return username, password, nil } \ No newline at end of file +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +} diff --git a/server/server.go b/server/server.go index cf8ec7b..61a8616 100755 --- a/server/server.go +++ b/server/server.go @@ -1 +1,531 @@ -package server import ( "context" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) // NewDefaultServer create a default authorization server func NewDefaultServer(manager oauth2.Manager) *Server { return NewServer(NewConfig(), manager) } // NewServer create authorization server func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ Config: cfg, Manager: manager, } // default handler srv.ClientInfoHandler = ClientBasicHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv } // Server Provide authorization server type Server struct { Config *Config Manager oauth2.Manager ClientInfoHandler ClientInfoHandler ClientAuthorizedHandler ClientAuthorizedHandler ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { uri, err := s.GetRedirectURI(req, data) if err != nil { return err } w.Header().Set("Location", uri) w.WriteHeader(302) return nil } func (s *Server) tokenError(w http.ResponseWriter, err error) error { data, statusCode, header := s.GetErrorData(err) return s.token(w, data, header, statusCode) } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") for key := range header { w.Header().Set(key, header.Get(key)) } status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } w.WriteHeader(status) return json.NewEncoder(w).Encode(data) } // GetRedirectURI get redirect uri func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { u, err := url.Parse(req.RedirectURI) if err != nil { return "", err } q := u.Query() if req.State != "" { q.Set("state", req.State) } for k, v := range data { q.Set(k, fmt.Sprint(v)) } switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() case oauth2.Token: u.RawQuery = "" fragment, err := url.QueryUnescape(q.Encode()) if err != nil { return "", err } u.Fragment = fragment } return u.String(), nil } // CheckResponseType check allows response type func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } } return false } // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { return nil, errors.ErrInvalidRequest } resType := oauth2.ResponseType(r.FormValue("response_type")) if resType.String() == "" { return nil, errors.ErrUnsupportedResponseType } else if allowed := s.CheckResponseType(resType); !allowed { return nil, errors.ErrUnauthorizedClient } req := &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: resType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, } return req, nil } // GetAuthorizeToken get authorization token(code) func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } allowed, err := fn(req.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } tgr := &oauth2.TokenGenerateRequest{ ClientID: req.ClientID, UserID: req.UserID, RedirectURI: req.RedirectURI, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, Request: req.Request, } // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } // GetAuthorizeData get authorization response data func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { if rt == oauth2.Code { return map[string]interface{}{ "code": ti.GetCode(), } } return s.GetTokenData(ti) } // HandleAuthorizeRequest the authorization request handling func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() req, err := s.ValidationAuthorizeRequest(r) if err != nil { return s.redirectError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { return s.redirectError(w, req, err) } else if userID == "" { return nil } req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { return err } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { return err } req.AccessTokenExp = exp } ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { return s.redirectError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. if req.RedirectURI == "" { client, err := s.Manager.GetClient(ctx, req.ClientID) if err != nil { return err } req.RedirectURI = client.GetDomain() } return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) } // ValidationTokenRequest the token request validation func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { if v := r.Method; !(v == "POST" || (s.Config.AllowGetAccessRequest && v == "GET")) { return "", nil, errors.ErrInvalidRequest } gt := oauth2.GrantType(r.FormValue("grant_type")) if gt.String() == "" { return "", nil, errors.ErrUnsupportedGrantType } clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, Request: r, } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") if tgr.RedirectURI == "" || tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") if username == "" || password == "" { return "", nil, errors.ErrInvalidRequest } userID, err := s.PasswordAuthorizationHandler(username, password) if err != nil { return "", nil, err } else if userID == "" { return "", nil, errors.ErrInvalidGrant } tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: tgr.Refresh = r.FormValue("refresh_token") tgr.Scope = r.FormValue("scope") if tgr.Refresh == "" { return "", nil, errors.ErrInvalidRequest } } return gt, tgr, nil } // CheckGrantType check allows grant type func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } } return false } // GetAccessToken access token func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } if fn := s.ClientAuthorizedHandler; fn != nil { allowed, err := fn(tgr.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } switch gt { case oauth2.AuthorizationCode: ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { case errors.ErrInvalidAuthorizeCode: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient default: return nil, err } } return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } return ti, nil } return nil, errors.ErrUnsupportedGrantType } // GetTokenData token data func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { data := map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope } if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { if _, ok := data[k]; ok { continue } data[k] = v } } return data } // HandleTokenRequest token request handling func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { return s.tokenError(w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { return s.tokenError(w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err re.Description = v re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { if v := fn(err); v != nil { re = *v } } if re.Error == nil { re.Error = errors.ErrServerError re.Description = errors.Descriptions[errors.ErrServerError] re.StatusCode = errors.StatusCodes[errors.ErrServerError] } } if fn := s.ResponseErrorHandler; fn != nil { fn(&re) } data := make(map[string]interface{}) if err := re.Error; err != nil { data["error"] = err.Error() } if v := re.ErrorCode; v != 0 { data["error_code"] = v } if v := re.Description; v != "" { data["error_description"] = v } if v := re.URI; v != "" { data["error_uri"] = v } statusCode := http.StatusInternalServerError if v := re.StatusCode; v > 0 { statusCode = v } return data, statusCode, re.Header } // BearerAuth parse bearer token func (s *Server) BearerAuth(r *http.Request) (string, bool) { auth := r.Header.Get("Authorization") prefix := "Bearer " token := "" if auth != "" && strings.HasPrefix(auth, prefix) { token = auth[len(prefix):] } else { token = r.FormValue("access_token") } return token, token != "" } // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() accessToken, ok := s.BearerAuth(r) if !ok { return nil, errors.ErrInvalidAccessToken } return s.Manager.LoadAccessToken(ctx, accessToken) } \ No newline at end of file +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +// NewDefaultServer create a default authorization server +func NewDefaultServer(manager oauth2.Manager) *Server { + return NewServer(NewConfig(), manager) +} + +// NewServer create authorization server +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + srv := &Server{ + Config: cfg, + Manager: manager, + } + + // default handler + srv.ClientInfoHandler = ClientBasicHandler + + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { + return "", errors.ErrAccessDenied + } + + srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + return "", errors.ErrAccessDenied + } + return srv +} + +// Server Provide authorization server +type Server struct { + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingScopeHandler RefreshingScopeHandler + ResponseErrorHandler ResponseErrorHandler + InternalErrorHandler InternalErrorHandler + ExtensionFieldsHandler ExtensionFieldsHandler + AccessTokenExpHandler AccessTokenExpHandler + AuthorizeScopeHandler AuthorizeScopeHandler +} + + +func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if req == nil { + return err + } + data, _, _ := s.GetErrorData(err) + return s.redirect(w, req, data) +} + +func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { + uri, err := s.GetRedirectURI(req, data) + if err != nil { + return err + } + + w.Header().Set("Location", uri) + w.WriteHeader(302) + return nil +} + +func (s *Server) tokenError(w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(err) + return s.token(w, data, header, statusCode) +} + +func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + for key := range header { + w.Header().Set(key, header.Get(key)) + } + + status := http.StatusOK + if len(statusCode) > 0 && statusCode[0] > 0 { + status = statusCode[0] + } + + w.WriteHeader(status) + return json.NewEncoder(w).Encode(data) +} + +// GetRedirectURI get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { + u, err := url.Parse(req.RedirectURI) + if err != nil { + return "", err + } + + q := u.Query() + if req.State != "" { + q.Set("state", req.State) + } + + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + + switch req.ResponseType { + case oauth2.Code: + u.RawQuery = q.Encode() + case oauth2.Token: + u.RawQuery = "" + fragment, err := url.QueryUnescape(q.Encode()) + if err != nil { + return "", err + } + u.Fragment = fragment + } + + return u.String(), nil +} + +// CheckResponseType check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { + if art == rt { + return true + } + } + return false +} + +// ValidationAuthorizeRequest the authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if !(r.Method == "GET" || r.Method == "POST") || + clientID == "" { + return nil, errors.ErrInvalidRequest + } + + resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { + return nil, errors.ErrUnsupportedResponseType + } else if allowed := s.CheckResponseType(resType); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + req := &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + } + return req, nil +} + +// GetAuthorizeToken get authorization token(code) +func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { + // check the client allows the grant type + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + + allowed, err := fn(req.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + // check the client allows the authorized scope + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) +} + +// GetAuthorizeData get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { + if rt == oauth2.Code { + return map[string]interface{}{ + "code": ti.GetCode(), + } + } + return s.GetTokenData(ti) +} + +// HandleAuthorizeRequest the authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + req, err := s.ValidationAuthorizeRequest(r) + if err != nil { + return s.redirectError(w, req, err) + } + + // user authorization + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + return s.redirectError(w, req, err) + } else if userID == "" { + return nil + } + req.UserID = userID + + // specify the scope of authorization + if fn := s.AuthorizeScopeHandler; fn != nil { + scope, err := fn(w, r) + if err != nil { + return err + } else if scope != "" { + req.Scope = scope + } + } + + // specify the expiration time of access token + if fn := s.AccessTokenExpHandler; fn != nil { + exp, err := fn(w, r) + if err != nil { + return err + } + req.AccessTokenExp = exp + } + + ti, err := s.GetAuthorizeToken(ctx, req) + if err != nil { + return s.redirectError(w, req, err) + } + + // If the redirect URI is empty, the default domain provided by the client is used. + if req.RedirectURI == "" { + client, err := s.Manager.GetClient(ctx, req.ClientID) + if err != nil { + return err + } + req.RedirectURI = client.GetDomain() + } + + return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) +} + +// ValidationTokenRequest the token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { + if v := r.Method; !(v == "POST" || + (s.Config.AllowGetAccessRequest && v == "GET")) { + return "", nil, errors.ErrInvalidRequest + } + + gt := oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { + return "", nil, errors.ErrUnsupportedGrantType + } + + clientID, clientSecret, err := s.ClientInfoHandler(r) + if err != nil { + return "", nil, err + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Request: r, + } + + switch gt { + case oauth2.AuthorizationCode: + tgr.RedirectURI = r.FormValue("redirect_uri") + tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || + tgr.Code == "" { + return "", nil, errors.ErrInvalidRequest + } + case oauth2.PasswordCredentials: + tgr.Scope = r.FormValue("scope") + username, password := r.FormValue("username"), r.FormValue("password") + if username == "" || password == "" { + return "", nil, errors.ErrInvalidRequest + } + + userID, err := s.PasswordAuthorizationHandler(username, password) + if err != nil { + return "", nil, err + } else if userID == "" { + return "", nil, errors.ErrInvalidGrant + } + tgr.UserID = userID + case oauth2.ClientCredentials: + tgr.Scope = r.FormValue("scope") + case oauth2.Refreshing: + tgr.Refresh = r.FormValue("refresh_token") + tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { + return "", nil, errors.ErrInvalidRequest + } + } + return gt, tgr, nil +} + +// CheckGrantType check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { + if agt == gt { + return true + } + } + return false +} + +// GetAccessToken access token +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + if allowed := s.CheckGrantType(gt); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + switch gt { + case oauth2.AuthorizationCode: + ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) + if err != nil { + switch err { + case errors.ErrInvalidAuthorizeCode: + return nil, errors.ErrInvalidGrant + case errors.ErrInvalidClient: + return nil, errors.ErrInvalidClient + default: + return nil, err + } + } + return ti, nil + case oauth2.PasswordCredentials, oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + return s.Manager.GenerateAccessToken(ctx, gt, tgr) + case oauth2.Refreshing: + // check scope + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + + allowed, err := scopeFn(tgr, rti.GetScope()) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + return ti, nil + } + + return nil, errors.ErrUnsupportedGrantType +} + +// GetTokenData token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), + } + + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } + return data +} + +// HandleTokenRequest token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + gt, tgr, err := s.ValidationTokenRequest(r) + if err != nil { + return s.tokenError(w, err) + } + + ti, err := s.GetAccessToken(ctx, gt, tgr) + if err != nil { + return s.tokenError(w, err) + } + + return s.token(w, s.GetTokenData(ti), nil) +} + +// GetErrorData get error response data +func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { + var re errors.Response + if v, ok := errors.Descriptions[err]; ok { + re.Error = err + re.Description = v + re.StatusCode = errors.StatusCodes[err] + } else { + if fn := s.InternalErrorHandler; fn != nil { + if v := fn(err); v != nil { + re = *v + } + } + + if re.Error == nil { + re.Error = errors.ErrServerError + re.Description = errors.Descriptions[errors.ErrServerError] + re.StatusCode = errors.StatusCodes[errors.ErrServerError] + } + } + + if fn := s.ResponseErrorHandler; fn != nil { + fn(&re) + } + + data := make(map[string]interface{}) + if err := re.Error; err != nil { + data["error"] = err.Error() + } + + if v := re.ErrorCode; v != 0 { + data["error_code"] = v + } + + if v := re.Description; v != "" { + data["error_description"] = v + } + + if v := re.URI; v != "" { + data["error_uri"] = v + } + + statusCode := http.StatusInternalServerError + if v := re.StatusCode; v > 0 { + statusCode = v + } + + return data, statusCode, re.Header +} + +// BearerAuth parse bearer token +func (s *Server) BearerAuth(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + prefix := "Bearer " + token := "" + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +// ValidationBearerToken validation the bearer tokens +// https://tools.ietf.org/html/rfc6750 +func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + ctx := r.Context() + + accessToken, ok := s.BearerAuth(r) + if !ok { + return nil, errors.ErrInvalidAccessToken + } + + return s.Manager.LoadAccessToken(ctx, accessToken) +} From 31d3a1d02aeb3e52b4c689fbf75daf318d5662e8 Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:34:40 +1000 Subject: [PATCH 5/9] Updated --- .editorconfig | 2 +- server/handler.go | 1 - server/server.go | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.editorconfig b/.editorconfig index 6c376a8..74aa1f6 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = lf \ No newline at end of file +end_of_line = crlf \ No newline at end of file diff --git a/server/handler.go b/server/handler.go index 10e0e96..df66137 100755 --- a/server/handler.go +++ b/server/handler.go @@ -43,7 +43,6 @@ type ( ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) ) - // ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (string, string, error) { clientID := r.Form.Get("client_id") diff --git a/server/server.go b/server/server.go index 61a8616..6887431 100755 --- a/server/server.go +++ b/server/server.go @@ -55,7 +55,6 @@ type Server struct { AuthorizeScopeHandler AuthorizeScopeHandler } - func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err From 669d51f1156acdd27be003b62d9ef6671842d09e Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:40:15 +1000 Subject: [PATCH 6/9] Updated --- .editorconfig | 2 +- server/server_config.go | 81 +---------------------------------------- 2 files changed, 2 insertions(+), 81 deletions(-) diff --git a/.editorconfig b/.editorconfig index 74aa1f6..d8549da 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = crlf \ No newline at end of file +end_of_line = cr \ No newline at end of file diff --git a/server/server_config.go b/server/server_config.go index d9b740d..af0477a 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -1,80 +1 @@ -package server - -import ( - "github.com/go-oauth2/oauth2/v4" -) - -// SetTokenType token type -func (s *Server) SetTokenType(tokenType string) { - s.Config.TokenType = tokenType -} - -// SetAllowGetAccessRequest to allow GET requests for the token -func (s *Server) SetAllowGetAccessRequest(allow bool) { - s.Config.AllowGetAccessRequest = allow -} - -// SetAllowedResponseType allow the authorization types -func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { - s.Config.AllowedResponseTypes = types -} - -// SetAllowedGrantType allow the grant types -func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { - s.Config.AllowedGrantTypes = types -} - -// SetClientInfoHandler get client info from request -func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { - s.ClientInfoHandler = handler -} - -// SetClientAuthorizedHandler check the client allows to use this authorization grant type -func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { - s.ClientAuthorizedHandler = handler -} - -// SetClientScopeHandler check the client allows to use scope -func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { - s.ClientScopeHandler = handler -} - -// SetUserAuthorizationHandler get user id from request authorization -func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { - s.UserAuthorizationHandler = handler -} - -// SetPasswordAuthorizationHandler get user id from username and password -func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { - s.PasswordAuthorizationHandler = handler -} - -// SetRefreshingScopeHandler check the scope of the refreshing token -func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { - s.RefreshingScopeHandler = handler -} - -// SetResponseErrorHandler response error handling -func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { - s.ResponseErrorHandler = handler -} - -// SetInternalErrorHandler internal error handling -func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { - s.InternalErrorHandler = handler -} - -// SetExtensionFieldsHandler in response to the access token with the extension of the field -func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { - s.ExtensionFieldsHandler = handler -} - -// SetAccessTokenExpHandler set expiration date for the access token -func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { - s.AccessTokenExpHandler = handler -} - -// SetAuthorizeScopeHandler set scope for the access token -func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { - s.AuthorizeScopeHandler = handler -} +package server import ( "github.com/go-oauth2/oauth2/v4" ) // SetTokenType token type func (s *Server) SetTokenType(tokenType string) { s.Config.TokenType = tokenType } // SetAllowGetAccessRequest to allow GET requests for the token func (s *Server) SetAllowGetAccessRequest(allow bool) { s.Config.AllowGetAccessRequest = allow } // SetAllowedResponseType allow the authorization types func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { s.Config.AllowedResponseTypes = types } // SetAllowedGrantType allow the grant types func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { s.Config.AllowedGrantTypes = types } // SetClientInfoHandler get client info from request func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { s.ClientInfoHandler = handler } // SetClientAuthorizedHandler check the client allows to use this authorization grant type func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { s.ClientAuthorizedHandler = handler } // SetClientScopeHandler check the client allows to use scope func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { s.ClientScopeHandler = handler } // SetUserAuthorizationHandler get user id from request authorization func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { s.UserAuthorizationHandler = handler } // SetPasswordAuthorizationHandler get user id from username and password func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { s.PasswordAuthorizationHandler = handler } // SetRefreshingScopeHandler check the scope of the refreshing token func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { s.RefreshingScopeHandler = handler } // SetResponseErrorHandler response error handling func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { s.ResponseErrorHandler = handler } // SetInternalErrorHandler internal error handling func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { s.InternalErrorHandler = handler } // SetExtensionFieldsHandler in response to the access token with the extension of the field func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { s.ExtensionFieldsHandler = handler } // SetAccessTokenExpHandler set expiration date for the access token func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { s.AccessTokenExpHandler = handler } // SetAuthorizeScopeHandler set scope for the access token func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { s.AuthorizeScopeHandler = handler } \ No newline at end of file From d7657a0a5ce24df7e3a46f922e3e6d6e4781364b Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:41:12 +1000 Subject: [PATCH 7/9] LF --- .editorconfig | 2 +- manage/manager.go | 472 +---------------------------------- server/handler.go | 64 +---- server/server.go | 531 +--------------------------------------- server/server_config.go | 81 +++++- 5 files changed, 84 insertions(+), 1066 deletions(-) diff --git a/.editorconfig b/.editorconfig index d8549da..6c376a8 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = cr \ No newline at end of file +end_of_line = lf \ No newline at end of file diff --git a/manage/manager.go b/manage/manager.go index b8fb01e..0ff1deb 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -1,471 +1 @@ -package manage - -import ( - "context" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" - "github.com/go-oauth2/oauth2/v4/generates" - "github.com/go-oauth2/oauth2/v4/models" -) - -// NewDefaultManager create to default authorization management instance -func NewDefaultManager() *Manager { - m := NewManager() - // default implementation - m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - m.MapAccessGenerate(generates.NewAccessGenerate()) - - return m -} - -// NewManager create to authorization management instance -func NewManager() *Manager { - return &Manager{ - gtcfg: make(map[oauth2.GrantType]*Config), - validateURI: DefaultValidateURI, - } -} - -// Manager provide authorization management -type Manager struct { - codeExp time.Duration - gtcfg map[oauth2.GrantType]*Config - rcfg *RefreshingConfig - validateURI ValidateURIHandler - authorizeGenerate oauth2.AuthorizeGenerate - accessGenerate oauth2.AccessGenerate - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore -} - -// get grant type config -func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { - if c, ok := m.gtcfg[gt]; ok && c != nil { - return c - } - switch gt { - case oauth2.AuthorizationCode: - return DefaultAuthorizeCodeTokenCfg - case oauth2.Implicit: - return DefaultImplicitTokenCfg - case oauth2.PasswordCredentials: - return DefaultPasswordTokenCfg - case oauth2.ClientCredentials: - return DefaultClientTokenCfg - } - return &Config{} -} - -// SetAuthorizeCodeExp set the authorization code expiration time -func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { - m.codeExp = exp -} - -// SetAuthorizeCodeTokenCfg set the authorization code grant token config -func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { - m.gtcfg[oauth2.AuthorizationCode] = cfg -} - -// SetImplicitTokenCfg set the implicit grant token config -func (m *Manager) SetImplicitTokenCfg(cfg *Config) { - m.gtcfg[oauth2.Implicit] = cfg -} - -// SetPasswordTokenCfg set the password grant token config -func (m *Manager) SetPasswordTokenCfg(cfg *Config) { - m.gtcfg[oauth2.PasswordCredentials] = cfg -} - -// SetClientTokenCfg set the client grant token config -func (m *Manager) SetClientTokenCfg(cfg *Config) { - m.gtcfg[oauth2.ClientCredentials] = cfg -} - -// SetRefreshTokenCfg set the refreshing token config -func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { - m.rcfg = cfg -} - -// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI -func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { - m.validateURI = handler -} - -// MapAuthorizeGenerate mapping the authorize code generate interface -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { - m.authorizeGenerate = gen -} - -// MapAccessGenerate mapping the access token generate interface -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { - m.accessGenerate = gen -} - -// MapClientStorage mapping the client store interface -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { - m.clientStore = stor -} - -// MustClientStorage mandatory mapping the client store interface -func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { - if err != nil { - panic(err.Error()) - } - m.clientStore = stor -} - -// MapTokenStorage mapping the token store interface -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { - m.tokenStore = stor -} - -// MustTokenStorage mandatory mapping the token store interface -func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { - if err != nil { - panic(err) - } - m.tokenStore = stor -} - -// GetClient get the client information -func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { - cli, err = m.clientStore.GetByID(ctx, clientID) - if err != nil { - return - } else if cli == nil { - err = errors.ErrInvalidClient - } - return -} - -// GenerateAuthToken generate the authorization token(code) -func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - switch rt { - case oauth2.Code: - codeExp := m.codeExp - if codeExp == 0 { - codeExp = DefaultCodeExp - } - ti.SetCodeCreateAt(createAt) - ti.SetCodeExpiresIn(codeExp) - if exp := tgr.AccessTokenExp; exp > 0 { - ti.SetAccessExpiresIn(exp) - } - - tv, err := m.authorizeGenerate.Token(ctx, td) - if err != nil { - return nil, err - } - ti.SetCode(tv) - case oauth2.Token: - // set access token expires - icfg := m.grantConfig(oauth2.Implicit) - aexp := icfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessCreateAt(createAt) - ti.SetAccessExpiresIn(aexp) - - if icfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(tv) - - if rv != "" { - ti.SetRefresh(rv) - } - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - return ti, nil -} - -// get authorization code data -func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - ti, err := m.tokenStore.GetByCode(ctx, code) - if err != nil { - return nil, err - } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { - err = errors.ErrInvalidAuthorizeCode - return nil, errors.ErrInvalidAuthorizeCode - } - return ti, nil -} - -// delete authorization code data -func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { - return m.tokenStore.RemoveByCode(ctx, code) -} - -// get and delete authorization code data -func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - code := tgr.Code - ti, err := m.getAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidAuthorizeCode - } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { - return nil, errors.ErrInvalidAuthorizeCode - } - - err = m.delAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } - return ti, nil -} - -// GenerateAccessToken generate the access token -func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } - if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - if gt == oauth2.AuthorizationCode { - ti, err := m.getAndDelAuthorizationCode(ctx, tgr) - if err != nil { - return nil, err - } - tgr.UserID = ti.GetUserID() - tgr.Scope = ti.GetScope() - if exp := ti.GetAccessExpiresIn(); exp > 0 { - tgr.AccessTokenExp = exp - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - ti.SetAccessCreateAt(createAt) - - // set access token expires - gcfg := m.grantConfig(gt) - aexp := gcfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessExpiresIn(aexp) - if gcfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) - } - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - - av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(av) - - if rv != "" { - ti.SetRefresh(rv) - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - - return ti, nil -} - -// RefreshAccessToken refreshing an access token -func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken - } - - oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: ti.GetUserID(), - CreateAt: time.Now(), - TokenInfo: ti, - Request: tgr.Request, - } - - rcfg := DefaultRefreshTokenCfg - if v := m.rcfg; v != nil { - rcfg = v - } - - ti.SetAccessCreateAt(td.CreateAt) - if v := rcfg.AccessTokenExp; v > 0 { - ti.SetAccessExpiresIn(v) - } - - if v := rcfg.RefreshTokenExp; v > 0 { - ti.SetRefreshExpiresIn(v) - } - - if rcfg.IsResetRefreshTime { - ti.SetRefreshCreateAt(td.CreateAt) - } - - if scope := tgr.Scope; scope != "" { - ti.SetScope(scope) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - - ti.SetAccess(tv) - if rv != "" { - ti.SetRefresh(rv) - } - - if err := m.tokenStore.Create(ctx, ti); err != nil { - return nil, err - } - - if rcfg.IsRemoveAccess { - // remove the old access token - if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { - return nil, err - } - } - - if rcfg.IsRemoveRefreshing && rv != "" { - // remove the old refresh token - if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { - return nil, err - } - } - - if rv == "" { - ti.SetRefresh("") - ti.SetRefreshCreateAt(time.Now()) - ti.SetRefreshExpiresIn(0) - } - - return ti, nil -} - -// RemoveAccessToken use the access token to delete the token information -func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { - if access == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByAccess(ctx, access) -} - -// RemoveRefreshToken use the refresh token to delete the token information -func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { - if refresh == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByRefresh(ctx, refresh) -} - -// LoadAccessToken according to the access token for corresponding token information -func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { - if access == "" { - return nil, errors.ErrInvalidAccessToken - } - - ct := time.Now() - ti, err := m.tokenStore.GetByAccess(ctx, access) - if err != nil { - return nil, err - } else if ti == nil || ti.GetAccess() != access { - return nil, errors.ErrInvalidAccessToken - } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredRefreshToken - } else if ti.GetAccessExpiresIn() != 0 && - ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredAccessToken - } - return ti, nil -} - -// LoadRefreshToken according to the refresh token for corresponding token information -func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - if refresh == "" { - return nil, errors.ErrInvalidRefreshToken - } - - ti, err := m.tokenStore.GetByRefresh(ctx, refresh) - if err != nil { - return nil, err - } else if ti == nil || ti.GetRefresh() != refresh { - return nil, errors.ErrInvalidRefreshToken - } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - return nil, errors.ErrExpiredRefreshToken - } - return ti, nil -} +package manage import ( "context" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" "github.com/go-oauth2/oauth2/v4/generates" "github.com/go-oauth2/oauth2/v4/models" ) // NewDefaultManager create to default authorization management instance func NewDefaultManager() *Manager { m := NewManager() // default implementation m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) m.MapAccessGenerate(generates.NewAccessGenerate()) return m } // NewManager create to authorization management instance func NewManager() *Manager { return &Manager{ gtcfg: make(map[oauth2.GrantType]*Config), validateURI: DefaultValidateURI, } } // Manager provide authorization management type Manager struct { codeExp time.Duration gtcfg map[oauth2.GrantType]*Config rcfg *RefreshingConfig validateURI ValidateURIHandler authorizeGenerate oauth2.AuthorizeGenerate accessGenerate oauth2.AccessGenerate tokenStore oauth2.TokenStore clientStore oauth2.ClientStore } // get grant type config func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { if c, ok := m.gtcfg[gt]; ok && c != nil { return c } switch gt { case oauth2.AuthorizationCode: return DefaultAuthorizeCodeTokenCfg case oauth2.Implicit: return DefaultImplicitTokenCfg case oauth2.PasswordCredentials: return DefaultPasswordTokenCfg case oauth2.ClientCredentials: return DefaultClientTokenCfg } return &Config{} } // SetAuthorizeCodeExp set the authorization code expiration time func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { m.codeExp = exp } // SetAuthorizeCodeTokenCfg set the authorization code grant token config func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { m.gtcfg[oauth2.AuthorizationCode] = cfg } // SetImplicitTokenCfg set the implicit grant token config func (m *Manager) SetImplicitTokenCfg(cfg *Config) { m.gtcfg[oauth2.Implicit] = cfg } // SetPasswordTokenCfg set the password grant token config func (m *Manager) SetPasswordTokenCfg(cfg *Config) { m.gtcfg[oauth2.PasswordCredentials] = cfg } // SetClientTokenCfg set the client grant token config func (m *Manager) SetClientTokenCfg(cfg *Config) { m.gtcfg[oauth2.ClientCredentials] = cfg } // SetRefreshTokenCfg set the refreshing token config func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { m.rcfg = cfg } // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { m.validateURI = handler } // MapAuthorizeGenerate mapping the authorize code generate interface func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.authorizeGenerate = gen } // MapAccessGenerate mapping the access token generate interface func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { m.accessGenerate = gen } // MapClientStorage mapping the client store interface func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { m.clientStore = stor } // MustClientStorage mandatory mapping the client store interface func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { panic(err.Error()) } m.clientStore = stor } // MapTokenStorage mapping the token store interface func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { m.tokenStore = stor } // MustTokenStorage mandatory mapping the token store interface func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } m.tokenStore = stor } // GetClient get the client information func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { cli, err = m.clientStore.GetByID(ctx, clientID) if err != nil { return } else if cli == nil { err = errors.ErrInvalidClient } return } // GenerateAuthToken generate the authorization token(code) func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } switch rt { case oauth2.Code: codeExp := m.codeExp if codeExp == 0 { codeExp = DefaultCodeExp } ti.SetCodeCreateAt(createAt) ti.SetCodeExpiresIn(codeExp) if exp := tgr.AccessTokenExp; exp > 0 { ti.SetAccessExpiresIn(exp) } tv, err := m.authorizeGenerate.Token(ctx, td) if err != nil { return nil, err } ti.SetCode(tv) case oauth2.Token: // set access token expires icfg := m.grantConfig(oauth2.Implicit) aexp := icfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessCreateAt(createAt) ti.SetAccessExpiresIn(aexp) if icfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) } tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // get authorization code data func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { ti, err := m.tokenStore.GetByCode(ctx, code) if err != nil { return nil, err } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { err = errors.ErrInvalidAuthorizeCode return nil, errors.ErrInvalidAuthorizeCode } return ti, nil } // delete authorization code data func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { return m.tokenStore.RemoveByCode(ctx, code) } // get and delete authorization code data func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { code := tgr.Code ti, err := m.getAuthorizationCode(ctx, code) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidAuthorizeCode } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { return nil, errors.ErrInvalidAuthorizeCode } err = m.delAuthorizationCode(ctx, code) if err != nil { return nil, err } return ti, nil } // GenerateAccessToken generate the access token func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { if !cliPass.VerifyPassword(tgr.ClientSecret) { return nil, errors.ErrInvalidClient } } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { return nil, err } tgr.UserID = ti.GetUserID() tgr.Scope = ti.GetScope() if exp := ti.GetAccessExpiresIn(); exp > 0 { tgr.AccessTokenExp = exp } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() ti.SetAccessCreateAt(createAt) // set access token expires gcfg := m.grantConfig(gt) aexp := gcfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessExpiresIn(aexp) if gcfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) } td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(av) if rv != "" { ti.SetRefresh(rv) } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // RefreshAccessToken refreshing an access token func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidRefreshToken } oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), CreateAt: time.Now(), TokenInfo: ti, Request: tgr.Request, } rcfg := DefaultRefreshTokenCfg if v := m.rcfg; v != nil { rcfg = v } ti.SetAccessCreateAt(td.CreateAt) if v := rcfg.AccessTokenExp; v > 0 { ti.SetAccessExpiresIn(v) } if v := rcfg.RefreshTokenExp; v > 0 { ti.SetRefreshExpiresIn(v) } if rcfg.IsResetRefreshTime { ti.SetRefreshCreateAt(td.CreateAt) } if scope := tgr.Scope; scope != "" { ti.SetScope(scope) } tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } if err := m.tokenStore.Create(ctx, ti); err != nil { return nil, err } if rcfg.IsRemoveAccess { // remove the old access token if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { return nil, err } } if rcfg.IsRemoveRefreshing && rv != "" { // remove the old refresh token if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { return nil, err } } if rv == "" { ti.SetRefresh("") ti.SetRefreshCreateAt(time.Now()) ti.SetRefreshExpiresIn(0) } return ti, nil } // RemoveAccessToken use the access token to delete the token information func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { if access == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByAccess(ctx, access) } // RemoveRefreshToken use the refresh token to delete the token information func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { if refresh == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByRefresh(ctx, refresh) } // LoadAccessToken according to the access token for corresponding token information func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, errors.ErrInvalidAccessToken } ct := time.Now() ti, err := m.tokenStore.GetByAccess(ctx, access) if err != nil { return nil, err } else if ti == nil || ti.GetAccess() != access { return nil, errors.ErrInvalidAccessToken } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { return nil, errors.ErrExpiredRefreshToken } else if ti.GetAccessExpiresIn() != 0 && ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { return nil, errors.ErrExpiredAccessToken } return ti, nil } // LoadRefreshToken according to the refresh token for corresponding token information func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, errors.ErrInvalidRefreshToken } ti, err := m.tokenStore.GetByRefresh(ctx, refresh) if err != nil { return nil, err } else if ti == nil || ti.GetRefresh() != refresh { return nil, errors.ErrInvalidRefreshToken } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { return nil, errors.ErrExpiredRefreshToken } return ti, nil } \ No newline at end of file diff --git a/server/handler.go b/server/handler.go index df66137..0474605 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,63 +1 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - if clientID == "" { - return "", "", errors.ErrInvalidClient - } - clientSecret := r.Form.Get("client_secret") - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server import ( "net/http" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) type ( // ClientInfoHandler get client info from request ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) // ClientAuthorizedHandler check the client allows to use this authorization grant type ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) // ClientScopeHandler check the client allows to use scope ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) // UserAuthorizationHandler get user id from request authorization UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password PasswordAuthorizationHandler func(username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) // ResponseErrorHandler response error handing ResponseErrorHandler func(re *errors.Response) // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) // AccessTokenExpHandler set expiration date for the access token AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) ) // ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (string, string, error) { clientID := r.Form.Get("client_id") if clientID == "" { return "", "", errors.ErrInvalidClient } clientSecret := r.Form.Get("client_secret") return clientID, clientSecret, nil } // ClientBasicHandler get client data from basic authorization func ClientBasicHandler(r *http.Request) (string, string, error) { username, password, ok := r.BasicAuth() if !ok { return "", "", errors.ErrInvalidClient } return username, password, nil } \ No newline at end of file diff --git a/server/server.go b/server/server.go index 6887431..effc8f6 100755 --- a/server/server.go +++ b/server/server.go @@ -1,530 +1 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(tgr, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} +package server import ( "context" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) // NewDefaultServer create a default authorization server func NewDefaultServer(manager oauth2.Manager) *Server { return NewServer(NewConfig(), manager) } // NewServer create authorization server func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ Config: cfg, Manager: manager, } // default handler srv.ClientInfoHandler = ClientBasicHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv } // Server Provide authorization server type Server struct { Config *Config Manager oauth2.Manager ClientInfoHandler ClientInfoHandler ClientAuthorizedHandler ClientAuthorizedHandler ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { uri, err := s.GetRedirectURI(req, data) if err != nil { return err } w.Header().Set("Location", uri) w.WriteHeader(302) return nil } func (s *Server) tokenError(w http.ResponseWriter, err error) error { data, statusCode, header := s.GetErrorData(err) return s.token(w, data, header, statusCode) } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") for key := range header { w.Header().Set(key, header.Get(key)) } status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } w.WriteHeader(status) return json.NewEncoder(w).Encode(data) } // GetRedirectURI get redirect uri func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { u, err := url.Parse(req.RedirectURI) if err != nil { return "", err } q := u.Query() if req.State != "" { q.Set("state", req.State) } for k, v := range data { q.Set(k, fmt.Sprint(v)) } switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() case oauth2.Token: u.RawQuery = "" fragment, err := url.QueryUnescape(q.Encode()) if err != nil { return "", err } u.Fragment = fragment } return u.String(), nil } // CheckResponseType check allows response type func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } } return false } // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { return nil, errors.ErrInvalidRequest } resType := oauth2.ResponseType(r.FormValue("response_type")) if resType.String() == "" { return nil, errors.ErrUnsupportedResponseType } else if allowed := s.CheckResponseType(resType); !allowed { return nil, errors.ErrUnauthorizedClient } req := &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: resType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, } return req, nil } // GetAuthorizeToken get authorization token(code) func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } allowed, err := fn(req.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } tgr := &oauth2.TokenGenerateRequest{ ClientID: req.ClientID, UserID: req.UserID, RedirectURI: req.RedirectURI, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, Request: req.Request, } // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } // GetAuthorizeData get authorization response data func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { if rt == oauth2.Code { return map[string]interface{}{ "code": ti.GetCode(), } } return s.GetTokenData(ti) } // HandleAuthorizeRequest the authorization request handling func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() req, err := s.ValidationAuthorizeRequest(r) if err != nil { return s.redirectError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { return s.redirectError(w, req, err) } else if userID == "" { return nil } req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { return err } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { return err } req.AccessTokenExp = exp } ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { return s.redirectError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. if req.RedirectURI == "" { client, err := s.Manager.GetClient(ctx, req.ClientID) if err != nil { return err } req.RedirectURI = client.GetDomain() } return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) } // ValidationTokenRequest the token request validation func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { if v := r.Method; !(v == "POST" || (s.Config.AllowGetAccessRequest && v == "GET")) { return "", nil, errors.ErrInvalidRequest } gt := oauth2.GrantType(r.FormValue("grant_type")) if gt.String() == "" { return "", nil, errors.ErrUnsupportedGrantType } clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, Request: r, } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") if tgr.RedirectURI == "" || tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") if username == "" || password == "" { return "", nil, errors.ErrInvalidRequest } userID, err := s.PasswordAuthorizationHandler(username, password) if err != nil { return "", nil, err } else if userID == "" { return "", nil, errors.ErrInvalidGrant } tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: tgr.Refresh = r.FormValue("refresh_token") tgr.Scope = r.FormValue("scope") if tgr.Refresh == "" { return "", nil, errors.ErrInvalidRequest } } return gt, tgr, nil } // CheckGrantType check allows grant type func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } } return false } // GetAccessToken access token func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } if fn := s.ClientAuthorizedHandler; fn != nil { allowed, err := fn(tgr.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } switch gt { case oauth2.AuthorizationCode: ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { case errors.ErrInvalidAuthorizeCode: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient default: return nil, err } } return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } return ti, nil } return nil, errors.ErrUnsupportedGrantType } // GetTokenData token data func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { data := map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope } if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { if _, ok := data[k]; ok { continue } data[k] = v } } return data } // HandleTokenRequest token request handling func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { return s.tokenError(w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { return s.tokenError(w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err re.Description = v re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { if v := fn(err); v != nil { re = *v } } if re.Error == nil { re.Error = errors.ErrServerError re.Description = errors.Descriptions[errors.ErrServerError] re.StatusCode = errors.StatusCodes[errors.ErrServerError] } } if fn := s.ResponseErrorHandler; fn != nil { fn(&re) } data := make(map[string]interface{}) if err := re.Error; err != nil { data["error"] = err.Error() } if v := re.ErrorCode; v != 0 { data["error_code"] = v } if v := re.Description; v != "" { data["error_description"] = v } if v := re.URI; v != "" { data["error_uri"] = v } statusCode := http.StatusInternalServerError if v := re.StatusCode; v > 0 { statusCode = v } return data, statusCode, re.Header } // BearerAuth parse bearer token func (s *Server) BearerAuth(r *http.Request) (string, bool) { auth := r.Header.Get("Authorization") prefix := "Bearer " token := "" if auth != "" && strings.HasPrefix(auth, prefix) { token = auth[len(prefix):] } else { token = r.FormValue("access_token") } return token, token != "" } // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() accessToken, ok := s.BearerAuth(r) if !ok { return nil, errors.ErrInvalidAccessToken } return s.Manager.LoadAccessToken(ctx, accessToken) } \ No newline at end of file diff --git a/server/server_config.go b/server/server_config.go index af0477a..d9b740d 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -1 +1,80 @@ -package server import ( "github.com/go-oauth2/oauth2/v4" ) // SetTokenType token type func (s *Server) SetTokenType(tokenType string) { s.Config.TokenType = tokenType } // SetAllowGetAccessRequest to allow GET requests for the token func (s *Server) SetAllowGetAccessRequest(allow bool) { s.Config.AllowGetAccessRequest = allow } // SetAllowedResponseType allow the authorization types func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { s.Config.AllowedResponseTypes = types } // SetAllowedGrantType allow the grant types func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { s.Config.AllowedGrantTypes = types } // SetClientInfoHandler get client info from request func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { s.ClientInfoHandler = handler } // SetClientAuthorizedHandler check the client allows to use this authorization grant type func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { s.ClientAuthorizedHandler = handler } // SetClientScopeHandler check the client allows to use scope func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { s.ClientScopeHandler = handler } // SetUserAuthorizationHandler get user id from request authorization func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { s.UserAuthorizationHandler = handler } // SetPasswordAuthorizationHandler get user id from username and password func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { s.PasswordAuthorizationHandler = handler } // SetRefreshingScopeHandler check the scope of the refreshing token func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { s.RefreshingScopeHandler = handler } // SetResponseErrorHandler response error handling func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { s.ResponseErrorHandler = handler } // SetInternalErrorHandler internal error handling func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { s.InternalErrorHandler = handler } // SetExtensionFieldsHandler in response to the access token with the extension of the field func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { s.ExtensionFieldsHandler = handler } // SetAccessTokenExpHandler set expiration date for the access token func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { s.AccessTokenExpHandler = handler } // SetAuthorizeScopeHandler set scope for the access token func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { s.AuthorizeScopeHandler = handler } \ No newline at end of file +package server + +import ( + "github.com/go-oauth2/oauth2/v4" +) + +// SetTokenType token type +func (s *Server) SetTokenType(tokenType string) { + s.Config.TokenType = tokenType +} + +// SetAllowGetAccessRequest to allow GET requests for the token +func (s *Server) SetAllowGetAccessRequest(allow bool) { + s.Config.AllowGetAccessRequest = allow +} + +// SetAllowedResponseType allow the authorization types +func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { + s.Config.AllowedResponseTypes = types +} + +// SetAllowedGrantType allow the grant types +func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { + s.Config.AllowedGrantTypes = types +} + +// SetClientInfoHandler get client info from request +func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { + s.ClientInfoHandler = handler +} + +// SetClientAuthorizedHandler check the client allows to use this authorization grant type +func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { + s.ClientAuthorizedHandler = handler +} + +// SetClientScopeHandler check the client allows to use scope +func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { + s.ClientScopeHandler = handler +} + +// SetUserAuthorizationHandler get user id from request authorization +func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { + s.UserAuthorizationHandler = handler +} + +// SetPasswordAuthorizationHandler get user id from username and password +func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { + s.PasswordAuthorizationHandler = handler +} + +// SetRefreshingScopeHandler check the scope of the refreshing token +func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { + s.RefreshingScopeHandler = handler +} + +// SetResponseErrorHandler response error handling +func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { + s.ResponseErrorHandler = handler +} + +// SetInternalErrorHandler internal error handling +func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { + s.InternalErrorHandler = handler +} + +// SetExtensionFieldsHandler in response to the access token with the extension of the field +func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { + s.ExtensionFieldsHandler = handler +} + +// SetAccessTokenExpHandler set expiration date for the access token +func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { + s.AccessTokenExpHandler = handler +} + +// SetAuthorizeScopeHandler set scope for the access token +func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { + s.AuthorizeScopeHandler = handler +} From 399b68cada52d4eca68f6cdc6bfce2a385207cce Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:41:54 +1000 Subject: [PATCH 8/9] LFCR --- .editorconfig | 2 +- manage/manager.go | 472 ++++++++++++++++++++++++++++++++++++++++- server/handler.go | 64 +++++- server/server.go | 531 +++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 1065 insertions(+), 4 deletions(-) diff --git a/.editorconfig b/.editorconfig index 6c376a8..74aa1f6 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,2 +1,2 @@ [*.go] -end_of_line = lf \ No newline at end of file +end_of_line = crlf \ No newline at end of file diff --git a/manage/manager.go b/manage/manager.go index 0ff1deb..f04ae50 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -1 +1,471 @@ -package manage import ( "context" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" "github.com/go-oauth2/oauth2/v4/generates" "github.com/go-oauth2/oauth2/v4/models" ) // NewDefaultManager create to default authorization management instance func NewDefaultManager() *Manager { m := NewManager() // default implementation m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) m.MapAccessGenerate(generates.NewAccessGenerate()) return m } // NewManager create to authorization management instance func NewManager() *Manager { return &Manager{ gtcfg: make(map[oauth2.GrantType]*Config), validateURI: DefaultValidateURI, } } // Manager provide authorization management type Manager struct { codeExp time.Duration gtcfg map[oauth2.GrantType]*Config rcfg *RefreshingConfig validateURI ValidateURIHandler authorizeGenerate oauth2.AuthorizeGenerate accessGenerate oauth2.AccessGenerate tokenStore oauth2.TokenStore clientStore oauth2.ClientStore } // get grant type config func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { if c, ok := m.gtcfg[gt]; ok && c != nil { return c } switch gt { case oauth2.AuthorizationCode: return DefaultAuthorizeCodeTokenCfg case oauth2.Implicit: return DefaultImplicitTokenCfg case oauth2.PasswordCredentials: return DefaultPasswordTokenCfg case oauth2.ClientCredentials: return DefaultClientTokenCfg } return &Config{} } // SetAuthorizeCodeExp set the authorization code expiration time func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { m.codeExp = exp } // SetAuthorizeCodeTokenCfg set the authorization code grant token config func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { m.gtcfg[oauth2.AuthorizationCode] = cfg } // SetImplicitTokenCfg set the implicit grant token config func (m *Manager) SetImplicitTokenCfg(cfg *Config) { m.gtcfg[oauth2.Implicit] = cfg } // SetPasswordTokenCfg set the password grant token config func (m *Manager) SetPasswordTokenCfg(cfg *Config) { m.gtcfg[oauth2.PasswordCredentials] = cfg } // SetClientTokenCfg set the client grant token config func (m *Manager) SetClientTokenCfg(cfg *Config) { m.gtcfg[oauth2.ClientCredentials] = cfg } // SetRefreshTokenCfg set the refreshing token config func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { m.rcfg = cfg } // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { m.validateURI = handler } // MapAuthorizeGenerate mapping the authorize code generate interface func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.authorizeGenerate = gen } // MapAccessGenerate mapping the access token generate interface func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { m.accessGenerate = gen } // MapClientStorage mapping the client store interface func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { m.clientStore = stor } // MustClientStorage mandatory mapping the client store interface func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { panic(err.Error()) } m.clientStore = stor } // MapTokenStorage mapping the token store interface func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { m.tokenStore = stor } // MustTokenStorage mandatory mapping the token store interface func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } m.tokenStore = stor } // GetClient get the client information func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { cli, err = m.clientStore.GetByID(ctx, clientID) if err != nil { return } else if cli == nil { err = errors.ErrInvalidClient } return } // GenerateAuthToken generate the authorization token(code) func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } switch rt { case oauth2.Code: codeExp := m.codeExp if codeExp == 0 { codeExp = DefaultCodeExp } ti.SetCodeCreateAt(createAt) ti.SetCodeExpiresIn(codeExp) if exp := tgr.AccessTokenExp; exp > 0 { ti.SetAccessExpiresIn(exp) } tv, err := m.authorizeGenerate.Token(ctx, td) if err != nil { return nil, err } ti.SetCode(tv) case oauth2.Token: // set access token expires icfg := m.grantConfig(oauth2.Implicit) aexp := icfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessCreateAt(createAt) ti.SetAccessExpiresIn(aexp) if icfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) } tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // get authorization code data func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { ti, err := m.tokenStore.GetByCode(ctx, code) if err != nil { return nil, err } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { err = errors.ErrInvalidAuthorizeCode return nil, errors.ErrInvalidAuthorizeCode } return ti, nil } // delete authorization code data func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { return m.tokenStore.RemoveByCode(ctx, code) } // get and delete authorization code data func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { code := tgr.Code ti, err := m.getAuthorizationCode(ctx, code) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidAuthorizeCode } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { return nil, errors.ErrInvalidAuthorizeCode } err = m.delAuthorizationCode(ctx, code) if err != nil { return nil, err } return ti, nil } // GenerateAccessToken generate the access token func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { if !cliPass.VerifyPassword(tgr.ClientSecret) { return nil, errors.ErrInvalidClient } } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { return nil, err } tgr.UserID = ti.GetUserID() tgr.Scope = ti.GetScope() if exp := ti.GetAccessExpiresIn(); exp > 0 { tgr.AccessTokenExp = exp } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() ti.SetAccessCreateAt(createAt) // set access token expires gcfg := m.grantConfig(gt) aexp := gcfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessExpiresIn(aexp) if gcfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) } td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(av) if rv != "" { ti.SetRefresh(rv) } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // RefreshAccessToken refreshing an access token func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidRefreshToken } oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), CreateAt: time.Now(), TokenInfo: ti, Request: tgr.Request, } rcfg := DefaultRefreshTokenCfg if v := m.rcfg; v != nil { rcfg = v } ti.SetAccessCreateAt(td.CreateAt) if v := rcfg.AccessTokenExp; v > 0 { ti.SetAccessExpiresIn(v) } if v := rcfg.RefreshTokenExp; v > 0 { ti.SetRefreshExpiresIn(v) } if rcfg.IsResetRefreshTime { ti.SetRefreshCreateAt(td.CreateAt) } if scope := tgr.Scope; scope != "" { ti.SetScope(scope) } tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } if err := m.tokenStore.Create(ctx, ti); err != nil { return nil, err } if rcfg.IsRemoveAccess { // remove the old access token if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { return nil, err } } if rcfg.IsRemoveRefreshing && rv != "" { // remove the old refresh token if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { return nil, err } } if rv == "" { ti.SetRefresh("") ti.SetRefreshCreateAt(time.Now()) ti.SetRefreshExpiresIn(0) } return ti, nil } // RemoveAccessToken use the access token to delete the token information func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { if access == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByAccess(ctx, access) } // RemoveRefreshToken use the refresh token to delete the token information func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { if refresh == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByRefresh(ctx, refresh) } // LoadAccessToken according to the access token for corresponding token information func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, errors.ErrInvalidAccessToken } ct := time.Now() ti, err := m.tokenStore.GetByAccess(ctx, access) if err != nil { return nil, err } else if ti == nil || ti.GetAccess() != access { return nil, errors.ErrInvalidAccessToken } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { return nil, errors.ErrExpiredRefreshToken } else if ti.GetAccessExpiresIn() != 0 && ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { return nil, errors.ErrExpiredAccessToken } return ti, nil } // LoadRefreshToken according to the refresh token for corresponding token information func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, errors.ErrInvalidRefreshToken } ti, err := m.tokenStore.GetByRefresh(ctx, refresh) if err != nil { return nil, err } else if ti == nil || ti.GetRefresh() != refresh { return nil, errors.ErrInvalidRefreshToken } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { return nil, errors.ErrExpiredRefreshToken } return ti, nil } \ No newline at end of file +package manage + +import ( + "context" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/models" +) + +// NewDefaultManager create to default authorization management instance +func NewDefaultManager() *Manager { + m := NewManager() + // default implementation + m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + m.MapAccessGenerate(generates.NewAccessGenerate()) + + return m +} + +// NewManager create to authorization management instance +func NewManager() *Manager { + return &Manager{ + gtcfg: make(map[oauth2.GrantType]*Config), + validateURI: DefaultValidateURI, + } +} + +// Manager provide authorization management +type Manager struct { + codeExp time.Duration + gtcfg map[oauth2.GrantType]*Config + rcfg *RefreshingConfig + validateURI ValidateURIHandler + authorizeGenerate oauth2.AuthorizeGenerate + accessGenerate oauth2.AccessGenerate + tokenStore oauth2.TokenStore + clientStore oauth2.ClientStore +} + +// get grant type config +func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { + if c, ok := m.gtcfg[gt]; ok && c != nil { + return c + } + switch gt { + case oauth2.AuthorizationCode: + return DefaultAuthorizeCodeTokenCfg + case oauth2.Implicit: + return DefaultImplicitTokenCfg + case oauth2.PasswordCredentials: + return DefaultPasswordTokenCfg + case oauth2.ClientCredentials: + return DefaultClientTokenCfg + } + return &Config{} +} + +// SetAuthorizeCodeExp set the authorization code expiration time +func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { + m.codeExp = exp +} + +// SetAuthorizeCodeTokenCfg set the authorization code grant token config +func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { + m.gtcfg[oauth2.AuthorizationCode] = cfg +} + +// SetImplicitTokenCfg set the implicit grant token config +func (m *Manager) SetImplicitTokenCfg(cfg *Config) { + m.gtcfg[oauth2.Implicit] = cfg +} + +// SetPasswordTokenCfg set the password grant token config +func (m *Manager) SetPasswordTokenCfg(cfg *Config) { + m.gtcfg[oauth2.PasswordCredentials] = cfg +} + +// SetClientTokenCfg set the client grant token config +func (m *Manager) SetClientTokenCfg(cfg *Config) { + m.gtcfg[oauth2.ClientCredentials] = cfg +} + +// SetRefreshTokenCfg set the refreshing token config +func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { + m.rcfg = cfg +} + +// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI +func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { + m.validateURI = handler +} + +// MapAuthorizeGenerate mapping the authorize code generate interface +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { + m.authorizeGenerate = gen +} + +// MapAccessGenerate mapping the access token generate interface +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { + m.accessGenerate = gen +} + +// MapClientStorage mapping the client store interface +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { + m.clientStore = stor +} + +// MustClientStorage mandatory mapping the client store interface +func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { + if err != nil { + panic(err.Error()) + } + m.clientStore = stor +} + +// MapTokenStorage mapping the token store interface +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { + m.tokenStore = stor +} + +// MustTokenStorage mandatory mapping the token store interface +func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { + if err != nil { + panic(err) + } + m.tokenStore = stor +} + +// GetClient get the client information +func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { + cli, err = m.clientStore.GetByID(ctx, clientID) + if err != nil { + return + } else if cli == nil { + err = errors.ErrInvalidClient + } + return +} + +// GenerateAuthToken generate the authorization token(code) +func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + switch rt { + case oauth2.Code: + codeExp := m.codeExp + if codeExp == 0 { + codeExp = DefaultCodeExp + } + ti.SetCodeCreateAt(createAt) + ti.SetCodeExpiresIn(codeExp) + if exp := tgr.AccessTokenExp; exp > 0 { + ti.SetAccessExpiresIn(exp) + } + + tv, err := m.authorizeGenerate.Token(ctx, td) + if err != nil { + return nil, err + } + ti.SetCode(tv) + case oauth2.Token: + // set access token expires + icfg := m.grantConfig(oauth2.Implicit) + aexp := icfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessCreateAt(createAt) + ti.SetAccessExpiresIn(aexp) + + if icfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(tv) + + if rv != "" { + ti.SetRefresh(rv) + } + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + return ti, nil +} + +// get authorization code data +func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + ti, err := m.tokenStore.GetByCode(ctx, code) + if err != nil { + return nil, err + } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { + err = errors.ErrInvalidAuthorizeCode + return nil, errors.ErrInvalidAuthorizeCode + } + return ti, nil +} + +// delete authorization code data +func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { + return m.tokenStore.RemoveByCode(ctx, code) +} + +// get and delete authorization code data +func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + code := tgr.Code + ti, err := m.getAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidAuthorizeCode + } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { + return nil, errors.ErrInvalidAuthorizeCode + } + + err = m.delAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } + return ti, nil +} + +// GenerateAccessToken generate the access token +func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } + if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { + if !cliPass.VerifyPassword(tgr.ClientSecret) { + return nil, errors.ErrInvalidClient + } + } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + if gt == oauth2.AuthorizationCode { + ti, err := m.getAndDelAuthorizationCode(ctx, tgr) + if err != nil { + return nil, err + } + tgr.UserID = ti.GetUserID() + tgr.Scope = ti.GetScope() + if exp := ti.GetAccessExpiresIn(); exp > 0 { + tgr.AccessTokenExp = exp + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + ti.SetAccessCreateAt(createAt) + + // set access token expires + gcfg := m.grantConfig(gt) + aexp := gcfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessExpiresIn(aexp) + if gcfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) + } + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + + av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(av) + + if rv != "" { + ti.SetRefresh(rv) + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + + return ti, nil +} + +// RefreshAccessToken refreshing an access token +func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + + ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidRefreshToken + } + + oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: ti.GetUserID(), + CreateAt: time.Now(), + TokenInfo: ti, + Request: tgr.Request, + } + + rcfg := DefaultRefreshTokenCfg + if v := m.rcfg; v != nil { + rcfg = v + } + + ti.SetAccessCreateAt(td.CreateAt) + if v := rcfg.AccessTokenExp; v > 0 { + ti.SetAccessExpiresIn(v) + } + + if v := rcfg.RefreshTokenExp; v > 0 { + ti.SetRefreshExpiresIn(v) + } + + if rcfg.IsResetRefreshTime { + ti.SetRefreshCreateAt(td.CreateAt) + } + + if scope := tgr.Scope; scope != "" { + ti.SetScope(scope) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + + ti.SetAccess(tv) + if rv != "" { + ti.SetRefresh(rv) + } + + if err := m.tokenStore.Create(ctx, ti); err != nil { + return nil, err + } + + if rcfg.IsRemoveAccess { + // remove the old access token + if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { + return nil, err + } + } + + if rcfg.IsRemoveRefreshing && rv != "" { + // remove the old refresh token + if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { + return nil, err + } + } + + if rv == "" { + ti.SetRefresh("") + ti.SetRefreshCreateAt(time.Now()) + ti.SetRefreshExpiresIn(0) + } + + return ti, nil +} + +// RemoveAccessToken use the access token to delete the token information +func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { + if access == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByAccess(ctx, access) +} + +// RemoveRefreshToken use the refresh token to delete the token information +func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { + if refresh == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByRefresh(ctx, refresh) +} + +// LoadAccessToken according to the access token for corresponding token information +func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { + if access == "" { + return nil, errors.ErrInvalidAccessToken + } + + ct := time.Now() + ti, err := m.tokenStore.GetByAccess(ctx, access) + if err != nil { + return nil, err + } else if ti == nil || ti.GetAccess() != access { + return nil, errors.ErrInvalidAccessToken + } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredRefreshToken + } else if ti.GetAccessExpiresIn() != 0 && + ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredAccessToken + } + return ti, nil +} + +// LoadRefreshToken according to the refresh token for corresponding token information +func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + if refresh == "" { + return nil, errors.ErrInvalidRefreshToken + } + + ti, err := m.tokenStore.GetByRefresh(ctx, refresh) + if err != nil { + return nil, err + } else if ti == nil || ti.GetRefresh() != refresh { + return nil, errors.ErrInvalidRefreshToken + } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { + return nil, errors.ErrExpiredRefreshToken + } + return ti, nil +} diff --git a/server/handler.go b/server/handler.go index 0474605..df66137 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1 +1,63 @@ -package server import ( "net/http" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) type ( // ClientInfoHandler get client info from request ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) // ClientAuthorizedHandler check the client allows to use this authorization grant type ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) // ClientScopeHandler check the client allows to use scope ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) // UserAuthorizationHandler get user id from request authorization UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password PasswordAuthorizationHandler func(username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) // ResponseErrorHandler response error handing ResponseErrorHandler func(re *errors.Response) // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) // AccessTokenExpHandler set expiration date for the access token AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) ) // ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (string, string, error) { clientID := r.Form.Get("client_id") if clientID == "" { return "", "", errors.ErrInvalidClient } clientSecret := r.Form.Get("client_secret") return clientID, clientSecret, nil } // ClientBasicHandler get client data from basic authorization func ClientBasicHandler(r *http.Request) (string, string, error) { username, password, ok := r.BasicAuth() if !ok { return "", "", errors.ErrInvalidClient } return username, password, nil } \ No newline at end of file +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +} diff --git a/server/server.go b/server/server.go index effc8f6..6887431 100755 --- a/server/server.go +++ b/server/server.go @@ -1 +1,530 @@ -package server import ( "context" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/errors" ) // NewDefaultServer create a default authorization server func NewDefaultServer(manager oauth2.Manager) *Server { return NewServer(NewConfig(), manager) } // NewServer create authorization server func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ Config: cfg, Manager: manager, } // default handler srv.ClientInfoHandler = ClientBasicHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv } // Server Provide authorization server type Server struct { Config *Config Manager oauth2.Manager ClientInfoHandler ClientInfoHandler ClientAuthorizedHandler ClientAuthorizedHandler ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { uri, err := s.GetRedirectURI(req, data) if err != nil { return err } w.Header().Set("Location", uri) w.WriteHeader(302) return nil } func (s *Server) tokenError(w http.ResponseWriter, err error) error { data, statusCode, header := s.GetErrorData(err) return s.token(w, data, header, statusCode) } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") for key := range header { w.Header().Set(key, header.Get(key)) } status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } w.WriteHeader(status) return json.NewEncoder(w).Encode(data) } // GetRedirectURI get redirect uri func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { u, err := url.Parse(req.RedirectURI) if err != nil { return "", err } q := u.Query() if req.State != "" { q.Set("state", req.State) } for k, v := range data { q.Set(k, fmt.Sprint(v)) } switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() case oauth2.Token: u.RawQuery = "" fragment, err := url.QueryUnescape(q.Encode()) if err != nil { return "", err } u.Fragment = fragment } return u.String(), nil } // CheckResponseType check allows response type func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } } return false } // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { return nil, errors.ErrInvalidRequest } resType := oauth2.ResponseType(r.FormValue("response_type")) if resType.String() == "" { return nil, errors.ErrUnsupportedResponseType } else if allowed := s.CheckResponseType(resType); !allowed { return nil, errors.ErrUnauthorizedClient } req := &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: resType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, } return req, nil } // GetAuthorizeToken get authorization token(code) func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } allowed, err := fn(req.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } tgr := &oauth2.TokenGenerateRequest{ ClientID: req.ClientID, UserID: req.UserID, RedirectURI: req.RedirectURI, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, Request: req.Request, } // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } // GetAuthorizeData get authorization response data func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { if rt == oauth2.Code { return map[string]interface{}{ "code": ti.GetCode(), } } return s.GetTokenData(ti) } // HandleAuthorizeRequest the authorization request handling func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() req, err := s.ValidationAuthorizeRequest(r) if err != nil { return s.redirectError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { return s.redirectError(w, req, err) } else if userID == "" { return nil } req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { return err } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { return err } req.AccessTokenExp = exp } ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { return s.redirectError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. if req.RedirectURI == "" { client, err := s.Manager.GetClient(ctx, req.ClientID) if err != nil { return err } req.RedirectURI = client.GetDomain() } return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) } // ValidationTokenRequest the token request validation func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { if v := r.Method; !(v == "POST" || (s.Config.AllowGetAccessRequest && v == "GET")) { return "", nil, errors.ErrInvalidRequest } gt := oauth2.GrantType(r.FormValue("grant_type")) if gt.String() == "" { return "", nil, errors.ErrUnsupportedGrantType } clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, Request: r, } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") if tgr.RedirectURI == "" || tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") if username == "" || password == "" { return "", nil, errors.ErrInvalidRequest } userID, err := s.PasswordAuthorizationHandler(username, password) if err != nil { return "", nil, err } else if userID == "" { return "", nil, errors.ErrInvalidGrant } tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: tgr.Refresh = r.FormValue("refresh_token") tgr.Scope = r.FormValue("scope") if tgr.Refresh == "" { return "", nil, errors.ErrInvalidRequest } } return gt, tgr, nil } // CheckGrantType check allows grant type func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } } return false } // GetAccessToken access token func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } if fn := s.ClientAuthorizedHandler; fn != nil { allowed, err := fn(tgr.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } switch gt { case oauth2.AuthorizationCode: ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { case errors.ErrInvalidAuthorizeCode: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient default: return nil, err } } return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } return ti, nil } return nil, errors.ErrUnsupportedGrantType } // GetTokenData token data func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { data := map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope } if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { if _, ok := data[k]; ok { continue } data[k] = v } } return data } // HandleTokenRequest token request handling func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { return s.tokenError(w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { return s.tokenError(w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err re.Description = v re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { if v := fn(err); v != nil { re = *v } } if re.Error == nil { re.Error = errors.ErrServerError re.Description = errors.Descriptions[errors.ErrServerError] re.StatusCode = errors.StatusCodes[errors.ErrServerError] } } if fn := s.ResponseErrorHandler; fn != nil { fn(&re) } data := make(map[string]interface{}) if err := re.Error; err != nil { data["error"] = err.Error() } if v := re.ErrorCode; v != 0 { data["error_code"] = v } if v := re.Description; v != "" { data["error_description"] = v } if v := re.URI; v != "" { data["error_uri"] = v } statusCode := http.StatusInternalServerError if v := re.StatusCode; v > 0 { statusCode = v } return data, statusCode, re.Header } // BearerAuth parse bearer token func (s *Server) BearerAuth(r *http.Request) (string, bool) { auth := r.Header.Get("Authorization") prefix := "Bearer " token := "" if auth != "" && strings.HasPrefix(auth, prefix) { token = auth[len(prefix):] } else { token = r.FormValue("access_token") } return token, token != "" } // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() accessToken, ok := s.BearerAuth(r) if !ok { return nil, errors.ErrInvalidAccessToken } return s.Manager.LoadAccessToken(ctx, accessToken) } \ No newline at end of file +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +// NewDefaultServer create a default authorization server +func NewDefaultServer(manager oauth2.Manager) *Server { + return NewServer(NewConfig(), manager) +} + +// NewServer create authorization server +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + srv := &Server{ + Config: cfg, + Manager: manager, + } + + // default handler + srv.ClientInfoHandler = ClientBasicHandler + + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { + return "", errors.ErrAccessDenied + } + + srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + return "", errors.ErrAccessDenied + } + return srv +} + +// Server Provide authorization server +type Server struct { + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingScopeHandler RefreshingScopeHandler + ResponseErrorHandler ResponseErrorHandler + InternalErrorHandler InternalErrorHandler + ExtensionFieldsHandler ExtensionFieldsHandler + AccessTokenExpHandler AccessTokenExpHandler + AuthorizeScopeHandler AuthorizeScopeHandler +} + +func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if req == nil { + return err + } + data, _, _ := s.GetErrorData(err) + return s.redirect(w, req, data) +} + +func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { + uri, err := s.GetRedirectURI(req, data) + if err != nil { + return err + } + + w.Header().Set("Location", uri) + w.WriteHeader(302) + return nil +} + +func (s *Server) tokenError(w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(err) + return s.token(w, data, header, statusCode) +} + +func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + for key := range header { + w.Header().Set(key, header.Get(key)) + } + + status := http.StatusOK + if len(statusCode) > 0 && statusCode[0] > 0 { + status = statusCode[0] + } + + w.WriteHeader(status) + return json.NewEncoder(w).Encode(data) +} + +// GetRedirectURI get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { + u, err := url.Parse(req.RedirectURI) + if err != nil { + return "", err + } + + q := u.Query() + if req.State != "" { + q.Set("state", req.State) + } + + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + + switch req.ResponseType { + case oauth2.Code: + u.RawQuery = q.Encode() + case oauth2.Token: + u.RawQuery = "" + fragment, err := url.QueryUnescape(q.Encode()) + if err != nil { + return "", err + } + u.Fragment = fragment + } + + return u.String(), nil +} + +// CheckResponseType check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { + if art == rt { + return true + } + } + return false +} + +// ValidationAuthorizeRequest the authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if !(r.Method == "GET" || r.Method == "POST") || + clientID == "" { + return nil, errors.ErrInvalidRequest + } + + resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { + return nil, errors.ErrUnsupportedResponseType + } else if allowed := s.CheckResponseType(resType); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + req := &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + } + return req, nil +} + +// GetAuthorizeToken get authorization token(code) +func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { + // check the client allows the grant type + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + + allowed, err := fn(req.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + // check the client allows the authorized scope + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) +} + +// GetAuthorizeData get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { + if rt == oauth2.Code { + return map[string]interface{}{ + "code": ti.GetCode(), + } + } + return s.GetTokenData(ti) +} + +// HandleAuthorizeRequest the authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + req, err := s.ValidationAuthorizeRequest(r) + if err != nil { + return s.redirectError(w, req, err) + } + + // user authorization + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + return s.redirectError(w, req, err) + } else if userID == "" { + return nil + } + req.UserID = userID + + // specify the scope of authorization + if fn := s.AuthorizeScopeHandler; fn != nil { + scope, err := fn(w, r) + if err != nil { + return err + } else if scope != "" { + req.Scope = scope + } + } + + // specify the expiration time of access token + if fn := s.AccessTokenExpHandler; fn != nil { + exp, err := fn(w, r) + if err != nil { + return err + } + req.AccessTokenExp = exp + } + + ti, err := s.GetAuthorizeToken(ctx, req) + if err != nil { + return s.redirectError(w, req, err) + } + + // If the redirect URI is empty, the default domain provided by the client is used. + if req.RedirectURI == "" { + client, err := s.Manager.GetClient(ctx, req.ClientID) + if err != nil { + return err + } + req.RedirectURI = client.GetDomain() + } + + return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) +} + +// ValidationTokenRequest the token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { + if v := r.Method; !(v == "POST" || + (s.Config.AllowGetAccessRequest && v == "GET")) { + return "", nil, errors.ErrInvalidRequest + } + + gt := oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { + return "", nil, errors.ErrUnsupportedGrantType + } + + clientID, clientSecret, err := s.ClientInfoHandler(r) + if err != nil { + return "", nil, err + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Request: r, + } + + switch gt { + case oauth2.AuthorizationCode: + tgr.RedirectURI = r.FormValue("redirect_uri") + tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || + tgr.Code == "" { + return "", nil, errors.ErrInvalidRequest + } + case oauth2.PasswordCredentials: + tgr.Scope = r.FormValue("scope") + username, password := r.FormValue("username"), r.FormValue("password") + if username == "" || password == "" { + return "", nil, errors.ErrInvalidRequest + } + + userID, err := s.PasswordAuthorizationHandler(username, password) + if err != nil { + return "", nil, err + } else if userID == "" { + return "", nil, errors.ErrInvalidGrant + } + tgr.UserID = userID + case oauth2.ClientCredentials: + tgr.Scope = r.FormValue("scope") + case oauth2.Refreshing: + tgr.Refresh = r.FormValue("refresh_token") + tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { + return "", nil, errors.ErrInvalidRequest + } + } + return gt, tgr, nil +} + +// CheckGrantType check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { + if agt == gt { + return true + } + } + return false +} + +// GetAccessToken access token +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + if allowed := s.CheckGrantType(gt); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + switch gt { + case oauth2.AuthorizationCode: + ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) + if err != nil { + switch err { + case errors.ErrInvalidAuthorizeCode: + return nil, errors.ErrInvalidGrant + case errors.ErrInvalidClient: + return nil, errors.ErrInvalidClient + default: + return nil, err + } + } + return ti, nil + case oauth2.PasswordCredentials, oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + return s.Manager.GenerateAccessToken(ctx, gt, tgr) + case oauth2.Refreshing: + // check scope + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + + allowed, err := scopeFn(tgr, rti.GetScope()) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + return ti, nil + } + + return nil, errors.ErrUnsupportedGrantType +} + +// GetTokenData token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), + } + + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } + return data +} + +// HandleTokenRequest token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + gt, tgr, err := s.ValidationTokenRequest(r) + if err != nil { + return s.tokenError(w, err) + } + + ti, err := s.GetAccessToken(ctx, gt, tgr) + if err != nil { + return s.tokenError(w, err) + } + + return s.token(w, s.GetTokenData(ti), nil) +} + +// GetErrorData get error response data +func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { + var re errors.Response + if v, ok := errors.Descriptions[err]; ok { + re.Error = err + re.Description = v + re.StatusCode = errors.StatusCodes[err] + } else { + if fn := s.InternalErrorHandler; fn != nil { + if v := fn(err); v != nil { + re = *v + } + } + + if re.Error == nil { + re.Error = errors.ErrServerError + re.Description = errors.Descriptions[errors.ErrServerError] + re.StatusCode = errors.StatusCodes[errors.ErrServerError] + } + } + + if fn := s.ResponseErrorHandler; fn != nil { + fn(&re) + } + + data := make(map[string]interface{}) + if err := re.Error; err != nil { + data["error"] = err.Error() + } + + if v := re.ErrorCode; v != 0 { + data["error_code"] = v + } + + if v := re.Description; v != "" { + data["error_description"] = v + } + + if v := re.URI; v != "" { + data["error_uri"] = v + } + + statusCode := http.StatusInternalServerError + if v := re.StatusCode; v > 0 { + statusCode = v + } + + return data, statusCode, re.Header +} + +// BearerAuth parse bearer token +func (s *Server) BearerAuth(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + prefix := "Bearer " + token := "" + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +// ValidationBearerToken validation the bearer tokens +// https://tools.ietf.org/html/rfc6750 +func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + ctx := r.Context() + + accessToken, ok := s.BearerAuth(r) + if !ok { + return nil, errors.ErrInvalidAccessToken + } + + return s.Manager.LoadAccessToken(ctx, accessToken) +} From 219ddb94a91ac669ad39b9607e22136764edbfde Mon Sep 17 00:00:00 2001 From: arran ubels Date: Tue, 21 Jul 2020 15:53:03 +1000 Subject: [PATCH 9/9] Git was causing the issue --- manage/manager.go | 942 ++++++++++++++++++++-------------------- server/handler.go | 126 +++--- server/server.go | 1060 ++++++++++++++++++++++----------------------- 3 files changed, 1064 insertions(+), 1064 deletions(-) diff --git a/manage/manager.go b/manage/manager.go index f04ae50..c7346c0 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -1,471 +1,471 @@ -package manage - -import ( - "context" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" - "github.com/go-oauth2/oauth2/v4/generates" - "github.com/go-oauth2/oauth2/v4/models" -) - -// NewDefaultManager create to default authorization management instance -func NewDefaultManager() *Manager { - m := NewManager() - // default implementation - m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - m.MapAccessGenerate(generates.NewAccessGenerate()) - - return m -} - -// NewManager create to authorization management instance -func NewManager() *Manager { - return &Manager{ - gtcfg: make(map[oauth2.GrantType]*Config), - validateURI: DefaultValidateURI, - } -} - -// Manager provide authorization management -type Manager struct { - codeExp time.Duration - gtcfg map[oauth2.GrantType]*Config - rcfg *RefreshingConfig - validateURI ValidateURIHandler - authorizeGenerate oauth2.AuthorizeGenerate - accessGenerate oauth2.AccessGenerate - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore -} - -// get grant type config -func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { - if c, ok := m.gtcfg[gt]; ok && c != nil { - return c - } - switch gt { - case oauth2.AuthorizationCode: - return DefaultAuthorizeCodeTokenCfg - case oauth2.Implicit: - return DefaultImplicitTokenCfg - case oauth2.PasswordCredentials: - return DefaultPasswordTokenCfg - case oauth2.ClientCredentials: - return DefaultClientTokenCfg - } - return &Config{} -} - -// SetAuthorizeCodeExp set the authorization code expiration time -func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { - m.codeExp = exp -} - -// SetAuthorizeCodeTokenCfg set the authorization code grant token config -func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { - m.gtcfg[oauth2.AuthorizationCode] = cfg -} - -// SetImplicitTokenCfg set the implicit grant token config -func (m *Manager) SetImplicitTokenCfg(cfg *Config) { - m.gtcfg[oauth2.Implicit] = cfg -} - -// SetPasswordTokenCfg set the password grant token config -func (m *Manager) SetPasswordTokenCfg(cfg *Config) { - m.gtcfg[oauth2.PasswordCredentials] = cfg -} - -// SetClientTokenCfg set the client grant token config -func (m *Manager) SetClientTokenCfg(cfg *Config) { - m.gtcfg[oauth2.ClientCredentials] = cfg -} - -// SetRefreshTokenCfg set the refreshing token config -func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { - m.rcfg = cfg -} - -// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI -func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { - m.validateURI = handler -} - -// MapAuthorizeGenerate mapping the authorize code generate interface -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { - m.authorizeGenerate = gen -} - -// MapAccessGenerate mapping the access token generate interface -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { - m.accessGenerate = gen -} - -// MapClientStorage mapping the client store interface -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { - m.clientStore = stor -} - -// MustClientStorage mandatory mapping the client store interface -func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { - if err != nil { - panic(err.Error()) - } - m.clientStore = stor -} - -// MapTokenStorage mapping the token store interface -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { - m.tokenStore = stor -} - -// MustTokenStorage mandatory mapping the token store interface -func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { - if err != nil { - panic(err) - } - m.tokenStore = stor -} - -// GetClient get the client information -func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { - cli, err = m.clientStore.GetByID(ctx, clientID) - if err != nil { - return - } else if cli == nil { - err = errors.ErrInvalidClient - } - return -} - -// GenerateAuthToken generate the authorization token(code) -func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - switch rt { - case oauth2.Code: - codeExp := m.codeExp - if codeExp == 0 { - codeExp = DefaultCodeExp - } - ti.SetCodeCreateAt(createAt) - ti.SetCodeExpiresIn(codeExp) - if exp := tgr.AccessTokenExp; exp > 0 { - ti.SetAccessExpiresIn(exp) - } - - tv, err := m.authorizeGenerate.Token(ctx, td) - if err != nil { - return nil, err - } - ti.SetCode(tv) - case oauth2.Token: - // set access token expires - icfg := m.grantConfig(oauth2.Implicit) - aexp := icfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessCreateAt(createAt) - ti.SetAccessExpiresIn(aexp) - - if icfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(tv) - - if rv != "" { - ti.SetRefresh(rv) - } - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - return ti, nil -} - -// get authorization code data -func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - ti, err := m.tokenStore.GetByCode(ctx, code) - if err != nil { - return nil, err - } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { - err = errors.ErrInvalidAuthorizeCode - return nil, errors.ErrInvalidAuthorizeCode - } - return ti, nil -} - -// delete authorization code data -func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { - return m.tokenStore.RemoveByCode(ctx, code) -} - -// get and delete authorization code data -func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - code := tgr.Code - ti, err := m.getAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidAuthorizeCode - } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { - return nil, errors.ErrInvalidAuthorizeCode - } - - err = m.delAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } - return ti, nil -} - -// GenerateAccessToken generate the access token -func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } - if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - if gt == oauth2.AuthorizationCode { - ti, err := m.getAndDelAuthorizationCode(ctx, tgr) - if err != nil { - return nil, err - } - tgr.UserID = ti.GetUserID() - tgr.Scope = ti.GetScope() - if exp := ti.GetAccessExpiresIn(); exp > 0 { - tgr.AccessTokenExp = exp - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - ti.SetAccessCreateAt(createAt) - - // set access token expires - gcfg := m.grantConfig(gt) - aexp := gcfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessExpiresIn(aexp) - if gcfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) - } - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - - av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(av) - - if rv != "" { - ti.SetRefresh(rv) - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - - return ti, nil -} - -// RefreshAccessToken refreshing an access token -func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken - } - - oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: ti.GetUserID(), - CreateAt: time.Now(), - TokenInfo: ti, - Request: tgr.Request, - } - - rcfg := DefaultRefreshTokenCfg - if v := m.rcfg; v != nil { - rcfg = v - } - - ti.SetAccessCreateAt(td.CreateAt) - if v := rcfg.AccessTokenExp; v > 0 { - ti.SetAccessExpiresIn(v) - } - - if v := rcfg.RefreshTokenExp; v > 0 { - ti.SetRefreshExpiresIn(v) - } - - if rcfg.IsResetRefreshTime { - ti.SetRefreshCreateAt(td.CreateAt) - } - - if scope := tgr.Scope; scope != "" { - ti.SetScope(scope) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - - ti.SetAccess(tv) - if rv != "" { - ti.SetRefresh(rv) - } - - if err := m.tokenStore.Create(ctx, ti); err != nil { - return nil, err - } - - if rcfg.IsRemoveAccess { - // remove the old access token - if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { - return nil, err - } - } - - if rcfg.IsRemoveRefreshing && rv != "" { - // remove the old refresh token - if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { - return nil, err - } - } - - if rv == "" { - ti.SetRefresh("") - ti.SetRefreshCreateAt(time.Now()) - ti.SetRefreshExpiresIn(0) - } - - return ti, nil -} - -// RemoveAccessToken use the access token to delete the token information -func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { - if access == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByAccess(ctx, access) -} - -// RemoveRefreshToken use the refresh token to delete the token information -func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { - if refresh == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByRefresh(ctx, refresh) -} - -// LoadAccessToken according to the access token for corresponding token information -func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { - if access == "" { - return nil, errors.ErrInvalidAccessToken - } - - ct := time.Now() - ti, err := m.tokenStore.GetByAccess(ctx, access) - if err != nil { - return nil, err - } else if ti == nil || ti.GetAccess() != access { - return nil, errors.ErrInvalidAccessToken - } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredRefreshToken - } else if ti.GetAccessExpiresIn() != 0 && - ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredAccessToken - } - return ti, nil -} - -// LoadRefreshToken according to the refresh token for corresponding token information -func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - if refresh == "" { - return nil, errors.ErrInvalidRefreshToken - } - - ti, err := m.tokenStore.GetByRefresh(ctx, refresh) - if err != nil { - return nil, err - } else if ti == nil || ti.GetRefresh() != refresh { - return nil, errors.ErrInvalidRefreshToken - } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - return nil, errors.ErrExpiredRefreshToken - } - return ti, nil -} +package manage + +import ( + "context" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/models" +) + +// NewDefaultManager create to default authorization management instance +func NewDefaultManager() *Manager { + m := NewManager() + // default implementation + m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + m.MapAccessGenerate(generates.NewAccessGenerate()) + + return m +} + +// NewManager create to authorization management instance +func NewManager() *Manager { + return &Manager{ + gtcfg: make(map[oauth2.GrantType]*Config), + validateURI: DefaultValidateURI, + } +} + +// Manager provide authorization management +type Manager struct { + codeExp time.Duration + gtcfg map[oauth2.GrantType]*Config + rcfg *RefreshingConfig + validateURI ValidateURIHandler + authorizeGenerate oauth2.AuthorizeGenerate + accessGenerate oauth2.AccessGenerate + tokenStore oauth2.TokenStore + clientStore oauth2.ClientStore +} + +// get grant type config +func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { + if c, ok := m.gtcfg[gt]; ok && c != nil { + return c + } + switch gt { + case oauth2.AuthorizationCode: + return DefaultAuthorizeCodeTokenCfg + case oauth2.Implicit: + return DefaultImplicitTokenCfg + case oauth2.PasswordCredentials: + return DefaultPasswordTokenCfg + case oauth2.ClientCredentials: + return DefaultClientTokenCfg + } + return &Config{} +} + +// SetAuthorizeCodeExp set the authorization code expiration time +func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { + m.codeExp = exp +} + +// SetAuthorizeCodeTokenCfg set the authorization code grant token config +func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { + m.gtcfg[oauth2.AuthorizationCode] = cfg +} + +// SetImplicitTokenCfg set the implicit grant token config +func (m *Manager) SetImplicitTokenCfg(cfg *Config) { + m.gtcfg[oauth2.Implicit] = cfg +} + +// SetPasswordTokenCfg set the password grant token config +func (m *Manager) SetPasswordTokenCfg(cfg *Config) { + m.gtcfg[oauth2.PasswordCredentials] = cfg +} + +// SetClientTokenCfg set the client grant token config +func (m *Manager) SetClientTokenCfg(cfg *Config) { + m.gtcfg[oauth2.ClientCredentials] = cfg +} + +// SetRefreshTokenCfg set the refreshing token config +func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { + m.rcfg = cfg +} + +// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI +func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { + m.validateURI = handler +} + +// MapAuthorizeGenerate mapping the authorize code generate interface +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { + m.authorizeGenerate = gen +} + +// MapAccessGenerate mapping the access token generate interface +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { + m.accessGenerate = gen +} + +// MapClientStorage mapping the client store interface +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { + m.clientStore = stor +} + +// MustClientStorage mandatory mapping the client store interface +func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { + if err != nil { + panic(err.Error()) + } + m.clientStore = stor +} + +// MapTokenStorage mapping the token store interface +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { + m.tokenStore = stor +} + +// MustTokenStorage mandatory mapping the token store interface +func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { + if err != nil { + panic(err) + } + m.tokenStore = stor +} + +// GetClient get the client information +func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { + cli, err = m.clientStore.GetByID(ctx, clientID) + if err != nil { + return + } else if cli == nil { + err = errors.ErrInvalidClient + } + return +} + +// GenerateAuthToken generate the authorization token(code) +func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + switch rt { + case oauth2.Code: + codeExp := m.codeExp + if codeExp == 0 { + codeExp = DefaultCodeExp + } + ti.SetCodeCreateAt(createAt) + ti.SetCodeExpiresIn(codeExp) + if exp := tgr.AccessTokenExp; exp > 0 { + ti.SetAccessExpiresIn(exp) + } + + tv, err := m.authorizeGenerate.Token(ctx, td) + if err != nil { + return nil, err + } + ti.SetCode(tv) + case oauth2.Token: + // set access token expires + icfg := m.grantConfig(oauth2.Implicit) + aexp := icfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessCreateAt(createAt) + ti.SetAccessExpiresIn(aexp) + + if icfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(tv) + + if rv != "" { + ti.SetRefresh(rv) + } + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + return ti, nil +} + +// get authorization code data +func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + ti, err := m.tokenStore.GetByCode(ctx, code) + if err != nil { + return nil, err + } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { + err = errors.ErrInvalidAuthorizeCode + return nil, errors.ErrInvalidAuthorizeCode + } + return ti, nil +} + +// delete authorization code data +func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { + return m.tokenStore.RemoveByCode(ctx, code) +} + +// get and delete authorization code data +func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + code := tgr.Code + ti, err := m.getAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidAuthorizeCode + } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { + return nil, errors.ErrInvalidAuthorizeCode + } + + err = m.delAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } + return ti, nil +} + +// GenerateAccessToken generate the access token +func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } + if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { + if !cliPass.VerifyPassword(tgr.ClientSecret) { + return nil, errors.ErrInvalidClient + } + } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + if gt == oauth2.AuthorizationCode { + ti, err := m.getAndDelAuthorizationCode(ctx, tgr) + if err != nil { + return nil, err + } + tgr.UserID = ti.GetUserID() + tgr.Scope = ti.GetScope() + if exp := ti.GetAccessExpiresIn(); exp > 0 { + tgr.AccessTokenExp = exp + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + ti.SetAccessCreateAt(createAt) + + // set access token expires + gcfg := m.grantConfig(gt) + aexp := gcfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessExpiresIn(aexp) + if gcfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) + } + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + + av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(av) + + if rv != "" { + ti.SetRefresh(rv) + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + + return ti, nil +} + +// RefreshAccessToken refreshing an access token +func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + + ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidRefreshToken + } + + oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: ti.GetUserID(), + CreateAt: time.Now(), + TokenInfo: ti, + Request: tgr.Request, + } + + rcfg := DefaultRefreshTokenCfg + if v := m.rcfg; v != nil { + rcfg = v + } + + ti.SetAccessCreateAt(td.CreateAt) + if v := rcfg.AccessTokenExp; v > 0 { + ti.SetAccessExpiresIn(v) + } + + if v := rcfg.RefreshTokenExp; v > 0 { + ti.SetRefreshExpiresIn(v) + } + + if rcfg.IsResetRefreshTime { + ti.SetRefreshCreateAt(td.CreateAt) + } + + if scope := tgr.Scope; scope != "" { + ti.SetScope(scope) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + + ti.SetAccess(tv) + if rv != "" { + ti.SetRefresh(rv) + } + + if err := m.tokenStore.Create(ctx, ti); err != nil { + return nil, err + } + + if rcfg.IsRemoveAccess { + // remove the old access token + if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { + return nil, err + } + } + + if rcfg.IsRemoveRefreshing && rv != "" { + // remove the old refresh token + if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { + return nil, err + } + } + + if rv == "" { + ti.SetRefresh("") + ti.SetRefreshCreateAt(time.Now()) + ti.SetRefreshExpiresIn(0) + } + + return ti, nil +} + +// RemoveAccessToken use the access token to delete the token information +func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { + if access == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByAccess(ctx, access) +} + +// RemoveRefreshToken use the refresh token to delete the token information +func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { + if refresh == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByRefresh(ctx, refresh) +} + +// LoadAccessToken according to the access token for corresponding token information +func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { + if access == "" { + return nil, errors.ErrInvalidAccessToken + } + + ct := time.Now() + ti, err := m.tokenStore.GetByAccess(ctx, access) + if err != nil { + return nil, err + } else if ti == nil || ti.GetAccess() != access { + return nil, errors.ErrInvalidAccessToken + } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredRefreshToken + } else if ti.GetAccessExpiresIn() != 0 && + ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredAccessToken + } + return ti, nil +} + +// LoadRefreshToken according to the refresh token for corresponding token information +func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + if refresh == "" { + return nil, errors.ErrInvalidRefreshToken + } + + ti, err := m.tokenStore.GetByRefresh(ctx, refresh) + if err != nil { + return nil, err + } else if ti == nil || ti.GetRefresh() != refresh { + return nil, errors.ErrInvalidRefreshToken + } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { + return nil, errors.ErrExpiredRefreshToken + } + return ti, nil +} diff --git a/server/handler.go b/server/handler.go index df66137..a9202f8 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,63 +1,63 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - if clientID == "" { - return "", "", errors.ErrInvalidClient - } - clientSecret := r.Form.Get("client_secret") - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +} diff --git a/server/server.go b/server/server.go index 6887431..5a05ca7 100755 --- a/server/server.go +++ b/server/server.go @@ -1,530 +1,530 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(tgr, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +// NewDefaultServer create a default authorization server +func NewDefaultServer(manager oauth2.Manager) *Server { + return NewServer(NewConfig(), manager) +} + +// NewServer create authorization server +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + srv := &Server{ + Config: cfg, + Manager: manager, + } + + // default handler + srv.ClientInfoHandler = ClientBasicHandler + + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { + return "", errors.ErrAccessDenied + } + + srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + return "", errors.ErrAccessDenied + } + return srv +} + +// Server Provide authorization server +type Server struct { + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingScopeHandler RefreshingScopeHandler + ResponseErrorHandler ResponseErrorHandler + InternalErrorHandler InternalErrorHandler + ExtensionFieldsHandler ExtensionFieldsHandler + AccessTokenExpHandler AccessTokenExpHandler + AuthorizeScopeHandler AuthorizeScopeHandler +} + +func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if req == nil { + return err + } + data, _, _ := s.GetErrorData(err) + return s.redirect(w, req, data) +} + +func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { + uri, err := s.GetRedirectURI(req, data) + if err != nil { + return err + } + + w.Header().Set("Location", uri) + w.WriteHeader(302) + return nil +} + +func (s *Server) tokenError(w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(err) + return s.token(w, data, header, statusCode) +} + +func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + for key := range header { + w.Header().Set(key, header.Get(key)) + } + + status := http.StatusOK + if len(statusCode) > 0 && statusCode[0] > 0 { + status = statusCode[0] + } + + w.WriteHeader(status) + return json.NewEncoder(w).Encode(data) +} + +// GetRedirectURI get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { + u, err := url.Parse(req.RedirectURI) + if err != nil { + return "", err + } + + q := u.Query() + if req.State != "" { + q.Set("state", req.State) + } + + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + + switch req.ResponseType { + case oauth2.Code: + u.RawQuery = q.Encode() + case oauth2.Token: + u.RawQuery = "" + fragment, err := url.QueryUnescape(q.Encode()) + if err != nil { + return "", err + } + u.Fragment = fragment + } + + return u.String(), nil +} + +// CheckResponseType check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { + if art == rt { + return true + } + } + return false +} + +// ValidationAuthorizeRequest the authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if !(r.Method == "GET" || r.Method == "POST") || + clientID == "" { + return nil, errors.ErrInvalidRequest + } + + resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { + return nil, errors.ErrUnsupportedResponseType + } else if allowed := s.CheckResponseType(resType); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + req := &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + } + return req, nil +} + +// GetAuthorizeToken get authorization token(code) +func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { + // check the client allows the grant type + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + + allowed, err := fn(req.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + // check the client allows the authorized scope + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) +} + +// GetAuthorizeData get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { + if rt == oauth2.Code { + return map[string]interface{}{ + "code": ti.GetCode(), + } + } + return s.GetTokenData(ti) +} + +// HandleAuthorizeRequest the authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + req, err := s.ValidationAuthorizeRequest(r) + if err != nil { + return s.redirectError(w, req, err) + } + + // user authorization + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + return s.redirectError(w, req, err) + } else if userID == "" { + return nil + } + req.UserID = userID + + // specify the scope of authorization + if fn := s.AuthorizeScopeHandler; fn != nil { + scope, err := fn(w, r) + if err != nil { + return err + } else if scope != "" { + req.Scope = scope + } + } + + // specify the expiration time of access token + if fn := s.AccessTokenExpHandler; fn != nil { + exp, err := fn(w, r) + if err != nil { + return err + } + req.AccessTokenExp = exp + } + + ti, err := s.GetAuthorizeToken(ctx, req) + if err != nil { + return s.redirectError(w, req, err) + } + + // If the redirect URI is empty, the default domain provided by the client is used. + if req.RedirectURI == "" { + client, err := s.Manager.GetClient(ctx, req.ClientID) + if err != nil { + return err + } + req.RedirectURI = client.GetDomain() + } + + return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) +} + +// ValidationTokenRequest the token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { + if v := r.Method; !(v == "POST" || + (s.Config.AllowGetAccessRequest && v == "GET")) { + return "", nil, errors.ErrInvalidRequest + } + + gt := oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { + return "", nil, errors.ErrUnsupportedGrantType + } + + clientID, clientSecret, err := s.ClientInfoHandler(r) + if err != nil { + return "", nil, err + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Request: r, + } + + switch gt { + case oauth2.AuthorizationCode: + tgr.RedirectURI = r.FormValue("redirect_uri") + tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || + tgr.Code == "" { + return "", nil, errors.ErrInvalidRequest + } + case oauth2.PasswordCredentials: + tgr.Scope = r.FormValue("scope") + username, password := r.FormValue("username"), r.FormValue("password") + if username == "" || password == "" { + return "", nil, errors.ErrInvalidRequest + } + + userID, err := s.PasswordAuthorizationHandler(username, password) + if err != nil { + return "", nil, err + } else if userID == "" { + return "", nil, errors.ErrInvalidGrant + } + tgr.UserID = userID + case oauth2.ClientCredentials: + tgr.Scope = r.FormValue("scope") + case oauth2.Refreshing: + tgr.Refresh = r.FormValue("refresh_token") + tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { + return "", nil, errors.ErrInvalidRequest + } + } + return gt, tgr, nil +} + +// CheckGrantType check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { + if agt == gt { + return true + } + } + return false +} + +// GetAccessToken access token +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + if allowed := s.CheckGrantType(gt); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + switch gt { + case oauth2.AuthorizationCode: + ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) + if err != nil { + switch err { + case errors.ErrInvalidAuthorizeCode: + return nil, errors.ErrInvalidGrant + case errors.ErrInvalidClient: + return nil, errors.ErrInvalidClient + default: + return nil, err + } + } + return ti, nil + case oauth2.PasswordCredentials, oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + return s.Manager.GenerateAccessToken(ctx, gt, tgr) + case oauth2.Refreshing: + // check scope + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + + allowed, err := scopeFn(tgr, rti.GetScope()) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + return ti, nil + } + + return nil, errors.ErrUnsupportedGrantType +} + +// GetTokenData token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), + } + + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } + return data +} + +// HandleTokenRequest token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + gt, tgr, err := s.ValidationTokenRequest(r) + if err != nil { + return s.tokenError(w, err) + } + + ti, err := s.GetAccessToken(ctx, gt, tgr) + if err != nil { + return s.tokenError(w, err) + } + + return s.token(w, s.GetTokenData(ti), nil) +} + +// GetErrorData get error response data +func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { + var re errors.Response + if v, ok := errors.Descriptions[err]; ok { + re.Error = err + re.Description = v + re.StatusCode = errors.StatusCodes[err] + } else { + if fn := s.InternalErrorHandler; fn != nil { + if v := fn(err); v != nil { + re = *v + } + } + + if re.Error == nil { + re.Error = errors.ErrServerError + re.Description = errors.Descriptions[errors.ErrServerError] + re.StatusCode = errors.StatusCodes[errors.ErrServerError] + } + } + + if fn := s.ResponseErrorHandler; fn != nil { + fn(&re) + } + + data := make(map[string]interface{}) + if err := re.Error; err != nil { + data["error"] = err.Error() + } + + if v := re.ErrorCode; v != 0 { + data["error_code"] = v + } + + if v := re.Description; v != "" { + data["error_description"] = v + } + + if v := re.URI; v != "" { + data["error_uri"] = v + } + + statusCode := http.StatusInternalServerError + if v := re.StatusCode; v > 0 { + statusCode = v + } + + return data, statusCode, re.Header +} + +// BearerAuth parse bearer token +func (s *Server) BearerAuth(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + prefix := "Bearer " + token := "" + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +// ValidationBearerToken validation the bearer tokens +// https://tools.ietf.org/html/rfc6750 +func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + ctx := r.Context() + + accessToken, ok := s.BearerAuth(r) + if !ok { + return nil, errors.ErrInvalidAccessToken + } + + return s.Manager.LoadAccessToken(ctx, accessToken) +}