From 543664a8073d14657bd4d24188b02aba6df1cfe5 Mon Sep 17 00:00:00 2001 From: Lucian Holland Date: Mon, 3 Feb 2025 12:31:23 +0100 Subject: [PATCH] Proposed fix for missing WWW-Authenticate header Current implementation does not include the WWW-Authenticate header when returning a 401 for missing/invalid credentials when attempting to access the token endpoints. Fixes-468 Signed-off-by: Lucian Holland --- .../web/OAuth2ClientAuthenticationFilter.java | 36 +++++++++++-------- ...OAuth2ClientAuthenticationFilterTests.java | 8 +++-- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index f074534ea..d70e0826b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -24,6 +24,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.core.log.LogMessage; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; @@ -96,6 +97,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure; + /** + * Internal error code used to distinguish missing authentication from invalid + * authentication in order to display a WWW-Authenticate header when appropriate The + * default failure handler will convert this to the spec-compliant 'invalid_client' + * before returning to the caller. + */ + private static final String MISSING_CLIENT_AUTH_ERROR_CODE = "missing_client_auth"; + /** * Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided * parameters. @@ -140,9 +149,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse validateClientIdentifier(authenticationRequest); Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest); this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult); + filterChain.doFilter(request, response); + } + else { + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(MISSING_CLIENT_AUTH_ERROR_CODE)); } - filterChain.doFilter(request, response); - } catch (OAuth2AuthenticationException ex) { if (this.logger.isTraceEnabled()) { @@ -204,27 +216,23 @@ private void onAuthenticationFailure(HttpServletRequest request, HttpServletResp SecurityContextHolder.clearContext(); - // TODO - // The authorization server MAY return an HTTP 401 (Unauthorized) status code - // to indicate which HTTP authentication schemes are supported. - // If the client attempted to authenticate via the "Authorization" request header - // field, - // the authorization server MUST respond with an HTTP 401 (Unauthorized) status - // code and - // include the "WWW-Authenticate" response header field - // matching the authentication scheme used by the client. - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { + String errorCode = error.getErrorCode(); + String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); + + if (MISSING_CLIENT_AUTH_ERROR_CODE.equals(errorCode) || (OAuth2ErrorCodes.INVALID_CLIENT.equals(errorCode) + && authHeader != null && authHeader.trim().startsWith("Basic"))) { + errorCode = OAuth2ErrorCodes.INVALID_CLIENT; httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); + httpResponse.getHeaders().set("WWW-Authenticate", "Basic"); } else { httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); } // We don't want to reveal too much information to the caller so just return the // error code - OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode()); + OAuth2Error errorResponse = new OAuth2Error(errorCode); this.errorHttpResponseConverter.write(errorResponse, null, httpResponse); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java index 97dc1750b..a50ee5cb6 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; @@ -149,8 +150,10 @@ public void doFilterWhenRequestMatchesAndEmptyCredentialsThenNotProcessed() thro this.filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - verifyNoInteractions(this.authenticationManager); + verifyNoInteractions(this.authenticationManager, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test @@ -225,6 +228,7 @@ public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); request.setServletPath(this.filterProcessesUrl); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic invalid-secret"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class);