Skip to content

Commit

Permalink
Merge pull request #8 from ZiyamSanthosh/main-servlet-exception
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyamSanthosh authored Feb 7, 2025
2 parents fdb3992 + 499cbd3 commit 67aa3de
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.wso2.carbon.identity.local.auth.push.servlet;

import com.google.gson.Gson;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import org.apache.commons.lang.StringUtils;
Expand All @@ -29,6 +30,7 @@
import org.wso2.carbon.identity.local.auth.push.servlet.constant.PushServletConstants;
import org.wso2.carbon.identity.local.auth.push.servlet.impl.PushAuthStatusCacheManagerImpl;
import org.wso2.carbon.identity.local.auth.push.servlet.internal.PushServletDataHolder;
import org.wso2.carbon.identity.local.auth.push.servlet.model.ServletApiError;
import org.wso2.carbon.identity.notification.push.common.PushChallengeValidator;
import org.wso2.carbon.identity.notification.push.common.exception.PushTokenValidationException;
import org.wso2.carbon.identity.notification.push.device.handler.exception.PushDeviceHandlerException;
Expand All @@ -37,7 +39,6 @@
import java.io.IOException;
import java.text.ParseException;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -62,60 +63,83 @@ public class PushAuthServlet extends HttpServlet {
*
* @param request HTTP request
* @param response HTTP response
* @throws ServletException if an error occurs when handling the request
* @throws IOException if an I/O error occurs when handling the request
*/
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
protected void doPost(HttpServletRequest request, HttpServletResponse response) {

handleDeviceResponse(request, response);
try {
handleDeviceResponse(request, response);
} catch (IOException e) {
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_INTERNAL_SERVER_ERROR;
log.error(error.getDescription(), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
}
}

/**
* Handle the authentication response sent from the device.
*
* @param request HTTP request
* @param response HTTP response
* @throws IOException if an I/O error occurs when handling the request
* @throws ServletException if an error occurs when handling the request
* @throws IOException if an I/O error occurs when handling the request.
*/
private void handleDeviceResponse(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException {
private void handleDeviceResponse(HttpServletRequest request, HttpServletResponse response) throws IOException {

JSONObject jsonContent = readJsonContentInRequest(request);
String token = jsonContent.getString(AUTH_RESPONSE);
JSONObject jsonContent = readJsonContentInRequest(request, response);
if (jsonContent == null) {
return;
}

String token = jsonContent.getString(AUTH_RESPONSE);
if (StringUtils.isBlank(token)) {

PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND.toString());
log.debug(error.getDescription());
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);

response.sendError(HttpServletResponse.SC_BAD_REQUEST,
PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND.toString());
} else {

String deviceId = getDeviceIdFromToken(token);
JWTClaimsSet claimsSet = getClaimsSetFromAuthToken(token, deviceId);
String deviceId = getDeviceIdFromToken(token, response);
if (StringUtils.isBlank(deviceId)) {
return;
}

JWTClaimsSet claimsSet = getClaimsSetFromAuthToken(token, deviceId, response);
if (claimsSet == null) {
return;
}

String pushAuthId;
try {
pushAuthId = claimsSet.getStringClaim(PushServletConstants.TOKEN_PUSH_AUTH_ID);
} catch (ParseException e) {
throw new ServletException(PushServletConstants.ErrorMessages.ERROR_CODE_PARSE_JWT_FAILED.toString());
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_PARSE_JWT_FAILED;
if (log.isDebugEnabled()) {
log.debug(error.getDescription(), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return;
}

if (StringUtils.isBlank(pushAuthId)) {

String errorMessage = String.format(
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND.toString(), deviceId);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(errorMessage);
log.debug(String.format(error.getDescription(), deviceId));
}
response.sendError(HttpServletResponse.SC_BAD_REQUEST, errorMessage);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
} else {

addToContext(pushAuthId, token);
boolean isSuccessful = addToContext(pushAuthId, token, response);
if (!isSuccessful) {
return;
}
String status = PushServletConstants.Status.COMPLETED.name();
pushAuthStatusCacheManager.storeStatusCache(pushAuthId, status);

Expand All @@ -132,9 +156,10 @@ private void handleDeviceResponse(HttpServletRequest request, HttpServletRespons
* Read the JSON content in the request.
*
* @param request HTTP request
* @param response HTTP response
* @return JSON content in the request
*/
private JSONObject readJsonContentInRequest(HttpServletRequest request) throws ServletException {
private JSONObject readJsonContentInRequest(HttpServletRequest request, HttpServletResponse response) {

StringBuilder stringBuilder = new StringBuilder();
String line;
Expand All @@ -143,8 +168,11 @@ private JSONObject readJsonContentInRequest(HttpServletRequest request) throws S
stringBuilder.append(line);
}
} catch (IOException e) {
throw new ServletException(PushServletConstants
.ErrorMessages.ERROR_CODE_REQUEST_CONTENT_READ_FAILED.toString(), e);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_REQUEST_CONTENT_READ_FAILED;
log.error(error.getDescription(), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
return null;
}
String jsonString = stringBuilder.toString();
return new JSONObject(jsonString);
Expand All @@ -154,17 +182,22 @@ private JSONObject readJsonContentInRequest(HttpServletRequest request) throws S
* Derive the Device ID from the auth response token header.
*
* @param token Auth response token
* @param response HTTP response
* @return Device ID
* @throws ServletException if the token string fails to parse to JWT
*/
private String getDeviceIdFromToken(String token) throws ServletException {
private String getDeviceIdFromToken(String token, HttpServletResponse response) {

try {
return String.valueOf(JWTParser.parse(token).getHeader().getCustomParam(
PushServletConstants.TOKEN_DEVICE_ID));
} catch (ParseException e) {
throw new ServletException(PushServletConstants
.ErrorMessages.ERROR_CODE_GET_DEVICE_ID_FAILED.toString(), e);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_GET_DEVICE_ID_FAILED;
if (log.isDebugEnabled()) {
log.debug(error.getDescription(), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return null;
}
}

Expand All @@ -173,22 +206,28 @@ private String getDeviceIdFromToken(String token) throws ServletException {
*
* @param token Auth response token
* @param deviceId Device ID
* @param response HTTP response
* @return JWTClaimsSet
* @throws ServletException if the public key cannot be retrieved or the token validation fails
*/
private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId) throws ServletException {
private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId, HttpServletResponse response) {

try {
String publicKey = PushServletDataHolder.getInstance().getDeviceHandlerService().getPublicKey(deviceId);
return PushChallengeValidator.getValidatedClaimSet(token, publicKey);
} catch (PushDeviceHandlerException e) {
String errorMessage = String.format(PushServletConstants
.ErrorMessages.ERROR_CODE_GET_PUBLIC_KEY_FAILED.toString(), deviceId);
throw new ServletException(errorMessage);
PushServletConstants.ErrorMessages error
= PushServletConstants.ErrorMessages.ERROR_CODE_GET_PUBLIC_KEY_FAILED;
log.error(String.format(error.getDescription(), deviceId), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
return null;
} catch (PushTokenValidationException e) {
String errorMessage = String.format(PushServletConstants
.ErrorMessages.ERROR_CODE_TOKEN_VALIDATION_FAILED.toString(), deviceId);
throw new ServletException(errorMessage);
PushServletConstants.ErrorMessages error
= PushServletConstants.ErrorMessages.ERROR_CODE_TOKEN_VALIDATION_FAILED;
if (log.isDebugEnabled()) {
log.debug(String.format(error.getDescription(), deviceId), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return null;
}
}

Expand All @@ -197,12 +236,44 @@ private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId) th
*
* @param pushAuthId Push authentication ID
* @param token Auth response token
* @param response HTTP response
*/
private void addToContext(String pushAuthId, String token) {
private boolean addToContext(String pushAuthId, String token, HttpServletResponse response) {

PushAuthContextManager contextManager = PushServletDataHolder.getInstance().getPushAuthContextManager();
PushAuthContext context = contextManager.getContext(pushAuthId);
if (context == null) {
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_ERROR_AUTH_CONTEXT_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(String.format(error.getDescription(), pushAuthId));
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return false;
}
context.setToken(token);
contextManager.storeContext(pushAuthId, context);
return true;
}

/**
* Handle the API error response.
*
* @param response HTTP response
* @param error Error message
* @param statusCode HTTP status code
*/
private void handleAPIErrorResponse(HttpServletResponse response, PushServletConstants.ErrorMessages error,
int statusCode) {

try {
response.setStatus(statusCode);
ServletApiError servletApiError = new ServletApiError(error.getCode(), error.getMessage());
String jsonResponse = new Gson().toJson(servletApiError);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
} catch (IOException e) {
log.error("Error occurred while sending the error response.", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.wso2.carbon.identity.local.auth.push.servlet.constant.PushServletConstants;
import org.wso2.carbon.identity.local.auth.push.servlet.impl.PushAuthStatusCacheManagerImpl;
import org.wso2.carbon.identity.local.auth.push.servlet.model.PushAuthStatus;
import org.wso2.carbon.identity.local.auth.push.servlet.model.ServletApiError;

import java.io.IOException;

Expand All @@ -47,17 +48,25 @@ public class PushStatusServlet extends HttpServlet {
private static final PushAuthStatusCacheManager pushAuthStatusCacheManager = new PushAuthStatusCacheManagerImpl();

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
protected void doGet(HttpServletRequest request, HttpServletResponse response) {

if (!(request.getParameterMap().containsKey(AuthenticatorConstants.PUSH_AUTH_ID))) {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);

if (log.isDebugEnabled()) {
log.debug(PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND_IN_STATUS.toString());
}
handleAPIErrorResponse(response,
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND_IN_STATUS,
HttpServletResponse.SC_NOT_FOUND);

} else {
handleWebResponse(request, response);
try {
handleWebResponse(request, response);
} catch (IOException e) {
log.error("Error occurred while handling the push auth status response..", e);
handleAPIErrorResponse(response, PushServletConstants.ErrorMessages.ERROR_CODE_INTERNAL_SERVER_ERROR,
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
}
}
}

Expand Down Expand Up @@ -93,5 +102,27 @@ private void handleWebResponse(HttpServletRequest request, HttpServletResponse r

String jsonResponse = new Gson().toJson(pushAuthStatus);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
}

/**
* Handle the API error response.
*
* @param response HTTP response
* @param error Error message
* @param statusCode HTTP status code
*/
private void handleAPIErrorResponse(HttpServletResponse response, PushServletConstants.ErrorMessages error,
int statusCode) {

try {
response.setStatus(statusCode);
ServletApiError servletApiError = new ServletApiError(error.getCode(), error.getMessage());
String jsonResponse = new Gson().toJson(servletApiError);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
} catch (IOException e) {
log.error("Error occurred while sending the error response.", e);
}
}
}
Loading

0 comments on commit 67aa3de

Please sign in to comment.