Skip to content

Commit 9673635

Browse files
authored
feat: SSO improve auth handling (chainloop-dev#1862)
Signed-off-by: Miguel Martinez <[email protected]>
1 parent b9c1f32 commit 9673635

File tree

1 file changed

+58
-32
lines changed
  • app/controlplane/internal/service

1 file changed

+58
-32
lines changed

app/controlplane/internal/service/auth.go

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright 2024 The Chainloop Authors.
2+
// Copyright 2024-2025 The Chainloop Authors.
33
//
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
@@ -59,8 +59,34 @@ const (
5959
devUserDuration = 30 * longLivedDuration
6060
)
6161

62+
type oauthResp struct {
63+
code int
64+
err error
65+
showErrToUser bool
66+
}
67+
68+
// This is used to provide by default a generic error message to the user
69+
// unless showErrToUser is true
70+
func (e *oauthResp) ErrorMessage(l *log.Helper) string {
71+
if e.err != nil {
72+
// If the error is an internal server error, log it and raise it masked
73+
if e.code == http.StatusInternalServerError {
74+
return sl.LogAndMaskErr(e.err, l).Error()
75+
}
76+
// otherwise return the error message to the user
77+
// or the default status text
78+
if e.showErrToUser {
79+
return e.err.Error()
80+
}
81+
82+
return http.StatusText(e.code)
83+
}
84+
85+
return ""
86+
}
87+
6288
type oauthHandler struct {
63-
H func(*AuthService, http.ResponseWriter, *http.Request) (int, error)
89+
H func(*AuthService, http.ResponseWriter, *http.Request) *oauthResp
6490
svc *AuthService
6591
}
6692

@@ -164,17 +190,16 @@ func (svc *AuthService) RegisterLoginHandler() http.Handler {
164190

165191
// Implement http.Handler interface
166192
func (h oauthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
167-
status, err := h.H(h.svc, w, r)
168-
if err != nil {
169-
http.Error(w, http.StatusText(status), status)
193+
if err := h.H(h.svc, w, r); err != nil {
194+
http.Error(w, err.ErrorMessage(h.svc.log), err.code)
170195
}
171196
}
172197

173-
func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int, error) {
198+
func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) *oauthResp {
174199
b := make([]byte, 16)
175200
_, err := rand.Read(b)
176201
if err != nil {
177-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, nil)
202+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to generate random string: %w", err), false}
178203
}
179204

180205
// Store a random string to check it in the oauth callback
@@ -196,7 +221,7 @@ func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int
196221
if connectionStr != "" {
197222
uri, err := url.Parse(authorizationURI)
198223
if err != nil {
199-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
224+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to parse authorization URI: %w", err), false}
200225
}
201226
q := uri.Query()
202227
q.Set("connection", connectionStr)
@@ -205,7 +230,7 @@ func loginHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int
205230
}
206231

207232
http.Redirect(w, r, authorizationURI, http.StatusFound)
208-
return http.StatusTemporaryRedirect, nil
233+
return &oauthResp{http.StatusTemporaryRedirect, nil, false}
209234
}
210235

211236
// Extract custom claims
@@ -231,35 +256,35 @@ func (c *upstreamOIDCclaims) preferredEmail() string {
231256
return c.Email
232257
}
233258

234-
type errorWithCode struct {
235-
code int
236-
error
237-
}
238-
239-
func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (int, error) {
259+
func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) *oauthResp {
240260
ctx := context.Background()
261+
// if OIDC provider returns an error, return it to and show it to the user
262+
if desc := r.URL.Query().Get("error_description"); desc != "" {
263+
return &oauthResp{http.StatusUnauthorized, errors.New(desc), true}
264+
}
265+
241266
// Get information from google OIDC token
242267
claims, errWithCode := extractUserInfoFromToken(ctx, svc, r)
243268
if errWithCode != nil {
244-
return errWithCode.code, sl.LogAndMaskErr(errWithCode.error, svc.log)
269+
return &oauthResp{errWithCode.code, errWithCode.err, errWithCode.showErrToUser}
245270
}
246271

247272
// Create user if needed
248273
u, err := svc.userUseCase.FindOrCreateByEmail(ctx, claims.preferredEmail())
249274
if err != nil {
250-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
275+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to find or create user: %w", err), false}
251276
}
252277

253278
// Accept any pending invites
254279
if err := svc.orgInvitesUseCase.AcceptPendingInvitations(ctx, u.Email); err != nil {
255-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
280+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to accept pending invitations: %w", err), false}
256281
}
257282

258283
// Set the expiration
259284
expiration := shortLivedDuration
260285
longLived, err := r.Cookie(cookieLongLived)
261286
if err != nil {
262-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
287+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to get long lived cookie: %w", err), false}
263288
}
264289

265290
if longLived.Value == "true" {
@@ -269,32 +294,32 @@ func callbackHandler(svc *AuthService, w http.ResponseWriter, r *http.Request) (
269294
// Generate user token
270295
userToken, err := generateUserJWT(u.ID, svc.authConfig.GeneratedJwsHmacSecret, expiration)
271296
if err != nil {
272-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
297+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to generate user token: %w", err), false}
273298
}
274299

275300
// Either redirect or render the token if fallback is specified
276301
// Callback URL from the cookie
277302
callbackURLFromCookie, err := r.Cookie(cookieCallback)
278303
if err != nil {
279-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
304+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to get callback URL from cookie: %w", err), false}
280305
}
281306

282307
callbackValue := callbackURLFromCookie.Value
283308

284309
// There is no callback, just render the token
285310
if callbackValue == "" {
286311
fmt.Fprintf(w, "copy this token and paste it in your terminal window\n\n%s", userToken)
287-
return http.StatusOK, nil
312+
return &oauthResp{http.StatusOK, nil, false}
288313
}
289314

290315
// Redirect to the callback URL
291316
callbackURL, err := crafCallbackURL(callbackValue, userToken)
292317
if err != nil {
293-
return http.StatusInternalServerError, sl.LogAndMaskErr(err, svc.log)
318+
return &oauthResp{http.StatusInternalServerError, fmt.Errorf("failed to craft callback URL: %w", err), false}
294319
}
295320

296321
http.Redirect(w, r, callbackURL, http.StatusFound)
297-
return http.StatusTemporaryRedirect, nil
322+
return &oauthResp{http.StatusTemporaryRedirect, nil, false}
298323
}
299324

300325
func crafCallbackURL(callback, userToken string) (string, error) {
@@ -311,14 +336,15 @@ func crafCallbackURL(callback, userToken string) (string, error) {
311336
}
312337

313338
// Returns the claims from the OIDC token received during the OIDC callback
314-
func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Request) (*upstreamOIDCclaims, *errorWithCode) {
339+
func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Request) (*upstreamOIDCclaims, *oauthResp) {
315340
cookieState, err := r.Cookie(cookieOauthStateName)
341+
// if the cookie is not found, it likely means the authentication process has expired
316342
if err != nil {
317-
return nil, &errorWithCode{http.StatusUnauthorized, fmt.Errorf("retrieving cookie %s: %w", cookieOauthStateName, err)}
343+
return nil, &oauthResp{http.StatusUnauthorized, errors.New("the authentication process has expired, please try again"), true}
318344
}
319345

320346
if r.URL.Query().Get("state") != cookieState.Value {
321-
return nil, &errorWithCode{http.StatusUnauthorized, errors.New("oauth state does not match")}
347+
return nil, &oauthResp{http.StatusUnauthorized, errors.New("the authentication was invalid, please try again"), true}
322348
}
323349

324350
code := r.URL.Query().Get("code")
@@ -329,23 +355,23 @@ func extractUserInfoFromToken(ctx context.Context, svc *AuthService, r *http.Req
329355
// Exchange the code for a token
330356
oauth2Token, err := svc.authenticator.Exchange(ctx, code)
331357
if err != nil {
332-
return nil, &errorWithCode{http.StatusUnauthorized, err}
358+
return nil, &oauthResp{http.StatusUnauthorized, err, false}
333359
}
334360

335361
// It's a valid Oauth2 token
336362
if !oauth2Token.Valid() {
337-
return nil, &errorWithCode{http.StatusUnauthorized, errors.New("retrieved invalid Token")}
363+
return nil, &oauthResp{http.StatusUnauthorized, errors.New("retrieved invalid Token"), false}
338364
}
339365

340366
// Parse and verify ID token content and signature
341367
idToken, err := svc.authenticator.VerifyIDToken(ctx, oauth2Token)
342368
if err != nil {
343-
return nil, &errorWithCode{http.StatusInternalServerError, err}
369+
return nil, &oauthResp{http.StatusInternalServerError, err, false}
344370
}
345371

346372
var claims *upstreamOIDCclaims
347373
if err := idToken.Claims(&claims); err != nil {
348-
return nil, &errorWithCode{http.StatusInternalServerError, err}
374+
return nil, &oauthResp{http.StatusInternalServerError, err, false}
349375
}
350376

351377
return claims, nil
@@ -367,7 +393,7 @@ func generateUserJWT(userID, passphrase string, expiration time.Duration) (strin
367393
}
368394

369395
func setOauthCookie(w http.ResponseWriter, name, value string) {
370-
http.SetCookie(w, &http.Cookie{Name: name, Value: value, Path: "/", Expires: time.Now().Add(5 * time.Minute)})
396+
http.SetCookie(w, &http.Cookie{Name: name, Value: value, Path: "/", Expires: time.Now().Add(10 * time.Minute)})
371397
}
372398

373399
func generateAndLogDevUser(userUC *biz.UserUseCase, log *log.Helper, authConfig *conf.Auth) error {

0 commit comments

Comments
 (0)