Skip to content

Commit

Permalink
feat: SSO improve auth handling (chainloop-dev#1862)
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Martinez <[email protected]>
  • Loading branch information
migmartri authored Feb 28, 2025
1 parent b9c1f32 commit 9673635
Showing 1 changed file with 58 additions and 32 deletions.
90 changes: 58 additions & 32 deletions app/controlplane/internal/service/auth.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright 2024 The Chainloop Authors.
// Copyright 2024-2025 The Chainloop Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,8 +59,34 @@ const (
devUserDuration = 30 * longLivedDuration
)

type oauthResp struct {
code int
err error
showErrToUser bool
}

// This is used to provide by default a generic error message to the user
// unless showErrToUser is true
func (e *oauthResp) ErrorMessage(l *log.Helper) string {
if e.err != nil {
// If the error is an internal server error, log it and raise it masked
if e.code == http.StatusInternalServerError {
return sl.LogAndMaskErr(e.err, l).Error()
}
// otherwise return the error message to the user
// or the default status text
if e.showErrToUser {
return e.err.Error()
}

return http.StatusText(e.code)
}

return ""
}

type oauthHandler struct {
H func(*AuthService, http.ResponseWriter, *http.Request) (int, error)
H func(*AuthService, http.ResponseWriter, *http.Request) *oauthResp
svc *AuthService
}

Expand Down Expand Up @@ -164,17 +190,16 @@ func (svc *AuthService) RegisterLoginHandler() http.Handler {

// Implement http.Handler interface
func (h oauthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, err := h.H(h.svc, w, r)
if err != nil {
http.Error(w, http.StatusText(status), status)
if err := h.H(h.svc, w, r); err != nil {
http.Error(w, err.ErrorMessage(h.svc.log), err.code)
}
}

func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int, error) {
func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) *oauthResp {
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, nil)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to generate random string: %w", err), false}
}

// Store a random string to check it in the oauth callback
Expand All @@ -196,7 +221,7 @@ func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int
if connectionStr != "" {
uri, err := url.Parse(authorizationURI)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to parse authorization URI: %w", err), false}
}
q := uri.Query()
q.Set("connection", connectionStr)
Expand All @@ -205,7 +230,7 @@ func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int
}

http.Redirect(w, r, authorizationURI, http.StatusFound)
return http.StatusTemporaryRedirect, nil
return &oauthResp{http.StatusTemporaryRedirect, nil, false}
}

// Extract custom claims
Expand All @@ -231,35 +256,35 @@ func (c *upstreamOIDCclaims) preferredEmail() string {
return c.Email
}

type errorWithCode struct {
code int
error
}

func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int, error) {
func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) *oauthResp {
ctx := context.Background()
// if OIDC provider returns an error, return it to and show it to the user
if desc := r.URL.Query().Get("error_description"); desc != "" {
return &oauthResp{http.StatusUnauthorized, errors.New(desc), true}
}

// Get information from google OIDC token
claims, errWithCode := extractUserInfoFromToken(ctx, svc, r)
if errWithCode != nil {
return errWithCode.code, sl.LogAndMaskErr(errWithCode.error, svc.log)
return &oauthResp{errWithCode.code, errWithCode.err, errWithCode.showErrToUser}
}

// Create user if needed
u, err := svc.userUseCase.FindOrCreateByEmail(ctx, claims.preferredEmail())
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to find or create user: %w", err), false}
}

// Accept any pending invites
if err := svc.orgInvitesUseCase.AcceptPendingInvitations(ctx, u.Email); err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to accept pending invitations: %w", err), false}
}

// Set the expiration
expiration := shortLivedDuration
longLived, err := r.Cookie(cookieLongLived)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to get long lived cookie: %w", err), false}
}

if longLived.Value == "true" {
Expand All @@ -269,32 +294,32 @@ func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (
// Generate user token
userToken, err := generateUserJWT(u.ID, svc.authConfig.GeneratedJwsHmacSecret, expiration)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to generate user token: %w", err), false}
}

// Either redirect or render the token if fallback is specified
// Callback URL from the cookie
callbackURLFromCookie, err := r.Cookie(cookieCallback)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to get callback URL from cookie: %w", err), false}
}

callbackValue := callbackURLFromCookie.Value

// There is no callback, just render the token
if callbackValue == "" {
fmt.Fprintf(w, "copy this token and paste it in your terminal window\n\n%s", userToken)
return http.StatusOK, nil
return &oauthResp{http.StatusOK, nil, false}
}

// Redirect to the callback URL
callbackURL, err := crafCallbackURL(callbackValue, userToken)
if err != nil {
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to craft callback URL: %w", err), false}
}

http.Redirect(w, r, callbackURL, http.StatusFound)
return http.StatusTemporaryRedirect, nil
return &oauthResp{http.StatusTemporaryRedirect, nil, false}
}

func crafCallbackURL(callback, userToken string) (string, error) {
Expand All @@ -311,14 +336,15 @@ func crafCallbackURL(callback, userToken string) (string, error) {
}

// Returns the claims from the OIDC token received during the OIDC callback
func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Request) (*upstreamOIDCclaims, *errorWithCode) {
func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Request) (*upstreamOIDCclaims, *oauthResp) {
cookieState, err := r.Cookie(cookieOauthStateName)
// if the cookie is not found, it likely means the authentication process has expired
if err != nil {
return nil, &errorWithCode{http.StatusUnauthorized, fmt.Errorf("retrieving cookie %s: %w", cookieOauthStateName, err)}
return nil, &oauthResp{http.StatusUnauthorized, errors.New("the authentication process has expired, please try again"), true}
}

if r.URL.Query().Get("state") != cookieState.Value {
return nil, &errorWithCode{http.StatusUnauthorized, errors.New("oauth state does not match")}
return nil, &oauthResp{http.StatusUnauthorized, errors.New("the authentication was invalid, please try again"), true}
}

code := r.URL.Query().Get("code")
Expand All @@ -329,23 +355,23 @@ func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Req
// Exchange the code for a token
oauth2Token, err := svc.authenticator.Exchange(ctx, code)
if err != nil {
return nil, &errorWithCode{http.StatusUnauthorized, err}
return nil, &oauthResp{http.StatusUnauthorized, err, false}
}

// It's a valid Oauth2 token
if !oauth2Token.Valid() {
return nil, &errorWithCode{http.StatusUnauthorized, errors.New("retrieved invalid Token")}
return nil, &oauthResp{http.StatusUnauthorized, errors.New("retrieved invalid Token"), false}
}

// Parse and verify ID token content and signature
idToken, err := svc.authenticator.VerifyIDToken(ctx, oauth2Token)
if err != nil {
return nil, &errorWithCode{http.StatusInternalServerError, err}
return nil, &oauthResp{http.StatusInternalServerError, err, false}
}

var claims *upstreamOIDCclaims
if err := idToken.Claims(&claims); err != nil {
return nil, &errorWithCode{http.StatusInternalServerError, err}
return nil, &oauthResp{http.StatusInternalServerError, err, false}
}

return claims, nil
Expand All @@ -367,7 +393,7 @@ func generateUserJWT(userID, passphrase string, expiration time.Duration) (strin
}

func setOauthCookie(w http.ResponseWriter, name, value string) {
http.SetCookie(w, &http.Cookie{Name: name, Value: value, Path: "/", Expires: time.Now().Add(5 * time.Minute)})
http.SetCookie(w, &http.Cookie{Name: name, Value: value, Path: "/", Expires: time.Now().Add(10 * time.Minute)})
}

func generateAndLogDevUser(userUC *biz.UserUseCase, log *log.Helper, authConfig *conf.Auth) error {
Expand Down

0 comments on commit 9673635

Please sign in to comment.