Skip to content

Commit

Permalink
OIDC UserInfo Endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Stephen Crawford <[email protected]>
  • Loading branch information
stephen-crawford committed Aug 26, 2024
1 parent b601a92 commit c837dcf
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 80 deletions.
5 changes: 4 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,12 @@ dependencies {
runtimeOnly 'jakarta.xml.bind:jakarta.xml.bind-api:4.0.2'
runtimeOnly 'org.ow2.asm:asm:9.7'

implementation 'com.nimbusds:oauth2-oidc-sdk:11.18'
implementation 'net.minidev:json-smart:2.5.1'
implementation 'com.nimbusds:content-type:2.3'

testImplementation 'org.apache.camel:camel-xmlsecurity:3.22.2'
testImplementation 'org.mockito:mockito-inline:5.2.0'
//testImplementation 'org.mockito:mockito-inline:2.13.0'
//OpenSAML
implementation 'net.shibboleth.utilities:java-support:8.4.2'
runtimeOnly "io.dropwizard.metrics:metrics-core:4.2.27"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,18 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.text.ParseException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.config.RequestConfig;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder;
import org.apache.hc.client5.http.io.HttpClientConnectionManager;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -43,12 +37,18 @@

import com.amazon.dlic.auth.http.jwt.AbstractHTTPJwtAuthenticator;
import com.amazon.dlic.util.SettingsBasedSSLConfigurator;
import com.nimbusds.common.contenttype.ContentType;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.AccessTokenType;
import com.nimbusds.oauth2.sdk.util.StringUtils;
import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;

import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON;
import static org.apache.hc.core5.http.HttpHeaders.AUTHORIZATION;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.APPLICATION_JWT;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.CLIENT_ID;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.ISSUER_ID_URL;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.SUB_CLAIM;
Expand Down Expand Up @@ -139,67 +139,58 @@ public CloseableHttpClient createHttpClient() {
*/
public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContext context) throws OpenSearchSecurityException {

try (CloseableHttpClient httpClient = createHttpClient()) {
try {

HttpGet httpGet = new HttpGet(this.userInfoEndpoint);
URI userInfoEndpointURI = new URI(this.userInfoEndpoint);

RequestConfig requestConfig = RequestConfig.custom()
.setConnectionRequestTimeout(requestTimeoutMs, TimeUnit.MILLISECONDS)
.setConnectTimeout(requestTimeoutMs, TimeUnit.MILLISECONDS)
.build();

httpGet.setConfig(requestConfig);
httpGet.addHeader(AUTHORIZATION, request.getHeaders().get(AUTHORIZATION));
String bearerHeader = request.getHeaders().get(AUTHORIZATION).getFirst();
if (!StringUtils.isBlank(bearerHeader)) {
if (bearerHeader.contains("Bearer ")) {
bearerHeader = bearerHeader.substring(7);
}
}

// HTTPGet should internally verify the appropriate TLS cert.
try (CloseableHttpResponse response = httpClient.execute(httpGet)) {
String finalBearerHeader = bearerHeader;

if (response.getCode() < 200 || response.getCode() >= 300) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": " + response.getReasonPhrase()
);
AccessToken accessToken = new AccessToken(AccessTokenType.BEARER, finalBearerHeader) {
@Override
public String toAuthorizationHeader() {
return "Bearer " + finalBearerHeader;
}
};

HttpEntity httpEntity = response.getEntity();
UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoEndpointURI, accessToken);

if (httpEntity == null) {
throw new AuthenticatorUnavailableException("Error while getting " + this.userInfoEndpoint + ": Empty response entity");
}
HTTPRequest httpRequest = userInfoRequest.toHTTPRequest();

String contentType = httpEntity.getContentType();
if (!contentType.contains(APPLICATION_JSON.getMimeType()) && !contentType.contains(APPLICATION_JWT)) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": Invalid content type in response"
);
}
HTTPResponse httpResponse = httpRequest.send();
if (httpResponse.getStatusCode() < 200 || httpResponse.getStatusCode() >= 300) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": " + httpResponse.getStatusMessage()
);
}

try {

UserInfoResponse userInfoResponse = UserInfoResponse.parse(httpResponse);

String userinfoContent;

try (
// got this from ChatGpt & Amazon Q
InputStream inputStream = httpEntity.getContent();
InputStreamReader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)
) {
StringBuilder content = new StringBuilder();
char[] buffer = new char[8192];
int bytesRead;
while ((bytesRead = reader.read(buffer)) != -1) {
content.append(buffer, 0, bytesRead);
}
userinfoContent = content.toString();
} catch (IOException e) {
if (!userInfoResponse.indicatesSuccess()) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": Unable to read response content"
"Error while getting " + this.userInfoEndpoint + ": " + userInfoResponse.toErrorResponse()
);
}

String contentType = String.valueOf(httpResponse.getHeaderValues("content-type"));

JWTClaimsSet claims;
boolean isSigned = contentType.contains(APPLICATION_JWT);
if (contentType.contains(APPLICATION_JWT)) { // We don't need the userinfo_encrypted_response_alg since the
// selfRefreshingKeyProvider has access to the keys
claims = openIdJwtAuthenticator.getJwtClaimsSetFromInfoContent(userinfoContent);
boolean isSigned = contentType.contains(ContentType.APPLICATION_JWT.toString());
if (isSigned) { // We don't need the userinfo_encrypted_response_alg since the
// selfRefreshingKeyProvider has access to the keys
claims = openIdJwtAuthenticator.getJwtClaimsSetFromInfoContent(
userInfoResponse.toSuccessResponse().getUserInfoJWT().getParsedString()
);
} else {
claims = JWTClaimsSet.parse(userinfoContent);
claims = JWTClaimsSet.parse(userInfoResponse.toSuccessResponse().getUserInfo().toString());
}

String id = openIdJwtAuthenticator.getJwtClaimsSet(request).getSubject();
Expand Down Expand Up @@ -228,7 +219,7 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex
} catch (ParseException e) {
throw new RuntimeException(e);
}
} catch (IOException e) {
} catch (IOException | URISyntaxException | com.nimbusds.oauth2.sdk.ParseException e) {
throw new AuthenticatorUnavailableException("Error while getting " + this.userInfoEndpoint + ": " + e, e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

public class OpenIdConstants {

public static final String APPLICATION_JWT = "application/jwt";
public static final String CLIENT_ID = "client_id";
public static final String ISSUER_ID_URL = "issuer_id_url";
public static final String SUB_CLAIM = "sub";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.FakeRestRequest;

import com.nimbusds.common.contenttype.ContentType;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.APPLICATION_JWT;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.CLIENT_ID;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.ISSUER_ID_URL;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.MCCOY_SUBJECT;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.OIDC_TEST_AUD;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.OIDC_TEST_ISS;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.ROLES_CLAIM;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.STEPHEN_SUBJECT;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.spy;

Expand Down Expand Up @@ -435,7 +437,12 @@ public void userinfoEndpointReturnsJwtWithAllRequirementsTest() throws Exception

AuthCredentials creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -448,8 +455,8 @@ public void userinfoEndpointReturnsJwtWithAllRequirementsTest() throws Exception

@Test
public void userinfoEndpointReturnsJwtWithRequiredAudIssFailsTest() throws Exception { // Setting a required issuer or audience
// alongside userinfo endpoint settings causes
// failures in signed response cases
// alongside userinfo endpoint settings causes
// failures in signed response cases
Settings settings = Settings.builder()
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
.put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri())
Expand All @@ -466,7 +473,12 @@ public void userinfoEndpointReturnsJwtWithRequiredAudIssFailsTest() throws Excep
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -493,7 +505,12 @@ public void userinfoEndpointReturnsJwtWithMatchingRequiredAudIssPassesTest() thr

AuthCredentials creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_OIDC, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_OIDC,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -520,7 +537,12 @@ public void userinfoEndpointReturnsJwtMissingIssuerTest() throws Exception {
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand Down Expand Up @@ -548,7 +570,12 @@ public void userinfoEndpointReturnsJwtMissingAudienceTest() throws Exception {
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -575,7 +602,12 @@ public void userinfoEndpointReturnsJwtMismatchedSubTest() throws Exception {
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.STEPHEN_RSA_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.STEPHEN_RSA_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -600,7 +632,12 @@ public void userinfoEndpointReturnsJsonWithAllRequirementsTest() throws Exceptio

AuthCredentials creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -626,7 +663,12 @@ public void userinfoEndpointReturnsJsonMismatchedSubTest() throws Exception {
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.STEPHEN_RSA_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.STEPHEN_RSA_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -642,7 +684,7 @@ public void userinfoEndpointReturnsJsonMismatchedSubTest() throws Exception {
public void userinfoEndpointReturnsResponseNot2xxTest() throws Exception {
Settings settings = Settings.builder()
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
.put("userinfo_endpoint", mockIdpServer.getUserinfoUri())
.put("userinfo_endpoint", mockIdpServer.getBadUserInfoUri())
.put("required_issuer", TestJwts.TEST_ISSUER)
.put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience")
.build();
Expand All @@ -653,7 +695,7 @@ public void userinfoEndpointReturnsResponseNot2xxTest() throws Exception {
try {
creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of("Authorization", STEPHEN_SUBJECT, "Content-Type", ContentType.APPLICATION_JWT.toString()),
new HashMap<>()
).asSecurityRequest(),
null
Expand All @@ -680,7 +722,12 @@ public void userinfoEndpointReturnsJsonWithRequiredAudIssPassesTest() throws Exc

AuthCredentials creds = openIdAuthenticator.extractCredentials(
new FakeRestRequest(
ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT),
ImmutableMap.of(
"Authorization",
"Bearer " + TestJwts.MC_COY_SIGNED_OCT_1,
"Content-Type",
ContentType.APPLICATION_JWT.toString()
),
new HashMap<>()
).asSecurityRequest(),
null
Expand Down
Loading

0 comments on commit c837dcf

Please sign in to comment.