Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify push servlets to catch and handle exceptions instead of throwing #8

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.wso2.carbon.identity.local.auth.push.servlet;

import com.google.gson.Gson;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import org.apache.commons.lang.StringUtils;
Expand All @@ -29,6 +30,7 @@
import org.wso2.carbon.identity.local.auth.push.servlet.constant.PushServletConstants;
import org.wso2.carbon.identity.local.auth.push.servlet.impl.PushAuthStatusCacheManagerImpl;
import org.wso2.carbon.identity.local.auth.push.servlet.internal.PushServletDataHolder;
import org.wso2.carbon.identity.local.auth.push.servlet.model.ServletApiError;
import org.wso2.carbon.identity.notification.push.common.PushChallengeValidator;
import org.wso2.carbon.identity.notification.push.common.exception.PushTokenValidationException;
import org.wso2.carbon.identity.notification.push.device.handler.exception.PushDeviceHandlerException;
Expand All @@ -37,7 +39,6 @@
import java.io.IOException;
import java.text.ParseException;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -62,60 +63,83 @@ public class PushAuthServlet extends HttpServlet {
*
* @param request HTTP request
* @param response HTTP response
* @throws ServletException if an error occurs when handling the request
* @throws IOException if an I/O error occurs when handling the request
*/
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
protected void doPost(HttpServletRequest request, HttpServletResponse response) {

handleDeviceResponse(request, response);
try {
handleDeviceResponse(request, response);
} catch (IOException e) {
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_INTERNAL_SERVER_ERROR;
log.error(error.getDescription(), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
}
}

/**
* Handle the authentication response sent from the device.
*
* @param request HTTP request
* @param response HTTP response
* @throws IOException if an I/O error occurs when handling the request
* @throws ServletException if an error occurs when handling the request
* @throws IOException if an I/O error occurs when handling the request.
*/
private void handleDeviceResponse(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException {
private void handleDeviceResponse(HttpServletRequest request, HttpServletResponse response) throws IOException {

JSONObject jsonContent = readJsonContentInRequest(request);
String token = jsonContent.getString(AUTH_RESPONSE);
JSONObject jsonContent = readJsonContentInRequest(request, response);
if (jsonContent == null) {
return;
}

String token = jsonContent.getString(AUTH_RESPONSE);
if (StringUtils.isBlank(token)) {

PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND.toString());
log.debug(error.getDescription());
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);

response.sendError(HttpServletResponse.SC_BAD_REQUEST,
PushServletConstants.ErrorMessages.ERROR_CODE_AUTH_RESPONSE_TOKEN_NOT_FOUND.toString());
} else {

String deviceId = getDeviceIdFromToken(token);
JWTClaimsSet claimsSet = getClaimsSetFromAuthToken(token, deviceId);
String deviceId = getDeviceIdFromToken(token, response);
if (StringUtils.isBlank(deviceId)) {
return;
}

JWTClaimsSet claimsSet = getClaimsSetFromAuthToken(token, deviceId, response);
if (claimsSet == null) {
return;
}

String pushAuthId;
try {
pushAuthId = claimsSet.getStringClaim(PushServletConstants.TOKEN_PUSH_AUTH_ID);
} catch (ParseException e) {
throw new ServletException(PushServletConstants.ErrorMessages.ERROR_CODE_PARSE_JWT_FAILED.toString());
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_PARSE_JWT_FAILED;
if (log.isDebugEnabled()) {
log.debug(error.getDescription(), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return;
}

if (StringUtils.isBlank(pushAuthId)) {

String errorMessage = String.format(
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND.toString(), deviceId);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(errorMessage);
log.debug(String.format(error.getDescription(), deviceId));
}
response.sendError(HttpServletResponse.SC_BAD_REQUEST, errorMessage);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
} else {

addToContext(pushAuthId, token);
boolean isSuccessful = addToContext(pushAuthId, token, response);
if (!isSuccessful) {
return;
}
String status = PushServletConstants.Status.COMPLETED.name();
pushAuthStatusCacheManager.storeStatusCache(pushAuthId, status);

Expand All @@ -132,9 +156,10 @@ private void handleDeviceResponse(HttpServletRequest request, HttpServletRespons
* Read the JSON content in the request.
*
* @param request HTTP request
* @param response HTTP response
* @return JSON content in the request
*/
private JSONObject readJsonContentInRequest(HttpServletRequest request) throws ServletException {
private JSONObject readJsonContentInRequest(HttpServletRequest request, HttpServletResponse response) {

StringBuilder stringBuilder = new StringBuilder();
String line;
Expand All @@ -143,8 +168,11 @@ private JSONObject readJsonContentInRequest(HttpServletRequest request) throws S
stringBuilder.append(line);
}
} catch (IOException e) {
throw new ServletException(PushServletConstants
.ErrorMessages.ERROR_CODE_REQUEST_CONTENT_READ_FAILED.toString(), e);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_REQUEST_CONTENT_READ_FAILED;
log.error(error.getDescription(), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
return null;
}
String jsonString = stringBuilder.toString();
return new JSONObject(jsonString);
Expand All @@ -154,17 +182,22 @@ private JSONObject readJsonContentInRequest(HttpServletRequest request) throws S
* Derive the Device ID from the auth response token header.
*
* @param token Auth response token
* @param response HTTP response
* @return Device ID
* @throws ServletException if the token string fails to parse to JWT
*/
private String getDeviceIdFromToken(String token) throws ServletException {
private String getDeviceIdFromToken(String token, HttpServletResponse response) {

try {
return String.valueOf(JWTParser.parse(token).getHeader().getCustomParam(
PushServletConstants.TOKEN_DEVICE_ID));
} catch (ParseException e) {
throw new ServletException(PushServletConstants
.ErrorMessages.ERROR_CODE_GET_DEVICE_ID_FAILED.toString(), e);
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_GET_DEVICE_ID_FAILED;
if (log.isDebugEnabled()) {
log.debug(error.getDescription(), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return null;
}
}

Expand All @@ -173,22 +206,28 @@ private String getDeviceIdFromToken(String token) throws ServletException {
*
* @param token Auth response token
* @param deviceId Device ID
* @param response HTTP response
* @return JWTClaimsSet
* @throws ServletException if the public key cannot be retrieved or the token validation fails
*/
private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId) throws ServletException {
private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId, HttpServletResponse response) {

try {
String publicKey = PushServletDataHolder.getInstance().getDeviceHandlerService().getPublicKey(deviceId);
return PushChallengeValidator.getValidatedClaimSet(token, publicKey);
} catch (PushDeviceHandlerException e) {
String errorMessage = String.format(PushServletConstants
.ErrorMessages.ERROR_CODE_GET_PUBLIC_KEY_FAILED.toString(), deviceId);
throw new ServletException(errorMessage);
PushServletConstants.ErrorMessages error
= PushServletConstants.ErrorMessages.ERROR_CODE_GET_PUBLIC_KEY_FAILED;
log.error(String.format(error.getDescription(), deviceId), e);
handleAPIErrorResponse(response, error, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
return null;
} catch (PushTokenValidationException e) {
String errorMessage = String.format(PushServletConstants
.ErrorMessages.ERROR_CODE_TOKEN_VALIDATION_FAILED.toString(), deviceId);
throw new ServletException(errorMessage);
PushServletConstants.ErrorMessages error
= PushServletConstants.ErrorMessages.ERROR_CODE_TOKEN_VALIDATION_FAILED;
if (log.isDebugEnabled()) {
log.debug(String.format(error.getDescription(), deviceId), e);
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return null;
}
}

Expand All @@ -197,12 +236,44 @@ private JWTClaimsSet getClaimsSetFromAuthToken(String token, String deviceId) th
*
* @param pushAuthId Push authentication ID
* @param token Auth response token
* @param response HTTP response
*/
private void addToContext(String pushAuthId, String token) {
private boolean addToContext(String pushAuthId, String token, HttpServletResponse response) {

PushAuthContextManager contextManager = PushServletDataHolder.getInstance().getPushAuthContextManager();
PushAuthContext context = contextManager.getContext(pushAuthId);
if (context == null) {
PushServletConstants.ErrorMessages error =
PushServletConstants.ErrorMessages.ERROR_CODE_ERROR_AUTH_CONTEXT_NOT_FOUND;
if (log.isDebugEnabled()) {
log.debug(String.format(error.getDescription(), pushAuthId));
}
handleAPIErrorResponse(response, error, HttpServletResponse.SC_BAD_REQUEST);
return false;
}
context.setToken(token);
contextManager.storeContext(pushAuthId, context);
return true;
}

/**
* Handle the API error response.
*
* @param response HTTP response
* @param error Error message
* @param statusCode HTTP status code
*/
private void handleAPIErrorResponse(HttpServletResponse response, PushServletConstants.ErrorMessages error,
int statusCode) {

try {
response.setStatus(statusCode);
ServletApiError servletApiError = new ServletApiError(error.getCode(), error.getMessage());
String jsonResponse = new Gson().toJson(servletApiError);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
} catch (IOException e) {
log.error("Error occurred while sending the error response.", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.wso2.carbon.identity.local.auth.push.servlet.constant.PushServletConstants;
import org.wso2.carbon.identity.local.auth.push.servlet.impl.PushAuthStatusCacheManagerImpl;
import org.wso2.carbon.identity.local.auth.push.servlet.model.PushAuthStatus;
import org.wso2.carbon.identity.local.auth.push.servlet.model.ServletApiError;

import java.io.IOException;

Expand All @@ -47,17 +48,25 @@ public class PushStatusServlet extends HttpServlet {
private static final PushAuthStatusCacheManager pushAuthStatusCacheManager = new PushAuthStatusCacheManagerImpl();

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
protected void doGet(HttpServletRequest request, HttpServletResponse response) {

if (!(request.getParameterMap().containsKey(AuthenticatorConstants.PUSH_AUTH_ID))) {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);

if (log.isDebugEnabled()) {
log.debug(PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND_IN_STATUS.toString());
}
handleAPIErrorResponse(response,
PushServletConstants.ErrorMessages.ERROR_CODE_PUSH_AUTH_ID_NOT_FOUND_IN_STATUS,
HttpServletResponse.SC_NOT_FOUND);

} else {
handleWebResponse(request, response);
try {
handleWebResponse(request, response);
} catch (IOException e) {
log.error("Error occurred while handling the push auth status response..", e);
handleAPIErrorResponse(response, PushServletConstants.ErrorMessages.ERROR_CODE_INTERNAL_SERVER_ERROR,
HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
}
}
}

Expand Down Expand Up @@ -93,5 +102,27 @@ private void handleWebResponse(HttpServletRequest request, HttpServletResponse r

String jsonResponse = new Gson().toJson(pushAuthStatus);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
}

/**
* Handle the API error response.
*
* @param response HTTP response
* @param error Error message
* @param statusCode HTTP status code
*/
private void handleAPIErrorResponse(HttpServletResponse response, PushServletConstants.ErrorMessages error,
int statusCode) {

try {
response.setStatus(statusCode);
ServletApiError servletApiError = new ServletApiError(error.getCode(), error.getMessage());
String jsonResponse = new Gson().toJson(servletApiError);
response.getWriter().write(jsonResponse);
response.getWriter().flush();
} catch (IOException e) {
log.error("Error occurred while sending the error response.", e);
}
}
}
Loading