From 0e6b5124238f70ee3d941c0c80f4f710b3ee023f Mon Sep 17 00:00:00 2001
From: Ivan Golovko <iigolovko@ginc-it.ru>
Date: Tue, 25 Feb 2025 23:40:21 +0300
Subject: [PATCH] Provided uncached way to use OidcIdTokenDecoderFactory

Signed-off-by: iigolovko <iigolovko@ginc-it.ru>
---
 .../OidcIdTokenDecoderFactory.java            | 44 ++++++++++++++-----
 .../ReactiveOidcIdTokenDecoderFactory.java    | 44 ++++++++++++++-----
 .../OidcIdTokenDecoderFactoryTests.java       | 19 ++++++++
 ...eactiveOidcIdTokenDecoderFactoryTests.java | 19 ++++++++
 4 files changed, 104 insertions(+), 22 deletions(-)

diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java
index 3679b7e36e5..aa1535c2054 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java
@@ -58,6 +58,7 @@
  * @author Joe Grandja
  * @author Rafael Dominguez
  * @author Mark Heckler
+ * @author Ivan Golovko
  * @since 5.2
  * @see JwtDecoderFactory
  * @see ClientRegistration
@@ -78,7 +79,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
 
 	private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = createDefaultClaimTypeConverter();
 
-	private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
+	private final Map<String, JwtDecoder> jwtDecoders;
 
 	private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory();
 
@@ -88,6 +89,19 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
 	private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
 			clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
 
+	public OidcIdTokenDecoderFactory() {
+		this(true);
+	}
+
+	public OidcIdTokenDecoderFactory(boolean withCache) {
+		if (withCache) {
+			this.jwtDecoders = new ConcurrentHashMap<>();
+		}
+		else {
+			this.jwtDecoders = null;
+		}
+	}
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcIdToken}.
@@ -135,16 +149,24 @@ public static ClaimTypeConverter createDefaultClaimTypeConverter() {
 	@Override
 	public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
-		return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
-			NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
-			jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
-			Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
-				.apply(clientRegistration);
-			if (claimTypeConverter != null) {
-				jwtDecoder.setClaimSetConverter(claimTypeConverter);
-			}
-			return jwtDecoder;
-		});
+		if (this.jwtDecoders != null) {
+			return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(),
+					(key) -> createFreshDecoder(clientRegistration));
+		}
+		else {
+			return createFreshDecoder(clientRegistration);
+		}
+	}
+
+	private JwtDecoder createFreshDecoder(ClientRegistration clientRegistration) {
+		NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
+		jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
+		Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
+			.apply(clientRegistration);
+		if (claimTypeConverter != null) {
+			jwtDecoder.setClaimSetConverter(claimTypeConverter);
+		}
+		return jwtDecoder;
 	}
 
 	private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java
index 5c066d3bacd..d7ba23b2f1b 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java
@@ -59,6 +59,7 @@
  * @author Rafael Dominguez
  * @author Mark Heckler
  * @author Ubaid ur Rehman
+ * @author Ivan Golovko
  * @since 5.2
  * @see ReactiveJwtDecoderFactory
  * @see ClientRegistration
@@ -80,7 +81,7 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter(
 			createDefaultClaimTypeConverters());
 
-	private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
+	private final Map<String, ReactiveJwtDecoder> jwtDecoders;
 
 	private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory();
 
@@ -90,6 +91,19 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
 			clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
 
+	public ReactiveOidcIdTokenDecoderFactory() {
+		this(true);
+	}
+
+	public ReactiveOidcIdTokenDecoderFactory(boolean withCache) {
+		if (withCache) {
+			this.jwtDecoders = new ConcurrentHashMap<>();
+		}
+		else {
+			this.jwtDecoders = null;
+		}
+	}
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcIdToken}.
@@ -126,16 +140,24 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	@Override
 	public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
-		return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
-			NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
-			jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
-			Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
-				.apply(clientRegistration);
-			if (claimTypeConverter != null) {
-				jwtDecoder.setClaimSetConverter(claimTypeConverter);
-			}
-			return jwtDecoder;
-		});
+		if (this.jwtDecoders != null) {
+			return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(),
+					(key) -> createFreshDecoder(clientRegistration));
+		}
+		else {
+			return createFreshDecoder(clientRegistration);
+		}
+	}
+
+	private ReactiveJwtDecoder createFreshDecoder(ClientRegistration clientRegistration) {
+		NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
+		jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
+		Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
+			.apply(clientRegistration);
+		if (claimTypeConverter != null) {
+			jwtDecoder.setClaimSetConverter(claimTypeConverter);
+		}
+		return jwtDecoder;
 	}
 
 	private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java
index 33663bac650..7a3944d23ce 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java
@@ -34,6 +34,7 @@
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -46,6 +47,7 @@
 /**
  * @author Joe Grandja
  * @author Rafael Dominguez
+ * @author Ivan Golovko
  * @since 5.2
  */
 public class OidcIdTokenDecoderFactoryTests {
@@ -177,4 +179,21 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() {
 		verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
 	}
 
+	@Test
+	public void createDecoderTwiceWithCaching() {
+		ClientRegistration clientRegistration = this.registration.build();
+		JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderTwiceWithoutCaching() {
+		this.idTokenDecoderFactory = new OidcIdTokenDecoderFactory(false);
+		ClientRegistration clientRegistration = this.registration.build();
+		JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
 }
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java
index 8c5b70ea494..2a4e31aa6dc 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java
@@ -34,6 +34,7 @@
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -47,6 +48,7 @@
  * @author Joe Grandja
  * @author Rafael Dominguez
  * @author Ubaid ur Rehman
+ * @author Ivan Golovko
  * @since 5.2
  */
 public class ReactiveOidcIdTokenDecoderFactoryTests {
@@ -177,4 +179,21 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() {
 		verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
 	}
 
+	@Test
+	public void createDecoderTwiceWithCaching() {
+		ClientRegistration clientRegistration = this.registration.build();
+		ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderTwiceWithoutCaching() {
+		this.idTokenDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(false);
+		ClientRegistration clientRegistration = this.registration.build();
+		ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
 }