|
1 | 1 | package fi.hsl.common.redis;
|
2 | 2 |
|
| 3 | +import com.azure.core.credential.AccessToken; |
| 4 | +import com.azure.core.credential.TokenCredential; |
| 5 | +import com.azure.core.credential.TokenRequestContext; |
| 6 | +import com.azure.core.util.CoreUtils; |
| 7 | +import com.google.gson.JsonObject; |
| 8 | +import com.google.gson.JsonParser; |
3 | 9 | import fi.hsl.common.pulsar.PulsarApplicationContext;
|
4 | 10 | import fi.hsl.common.transitdata.TransitdataProperties;
|
5 | 11 | import org.jetbrains.annotations.NotNull;
|
|
10 | 16 | import redis.clients.jedis.params.ScanParams;
|
11 | 17 | import redis.clients.jedis.resps.ScanResult;
|
12 | 18 |
|
| 19 | +import java.nio.charset.StandardCharsets; |
| 20 | +import java.time.Duration; |
13 | 21 | import java.time.OffsetDateTime;
|
14 | 22 | import java.time.format.DateTimeFormatter;
|
15 | 23 | import java.util.*;
|
| 24 | +import java.util.concurrent.ThreadLocalRandom; |
16 | 25 |
|
17 | 26 | public class RedisUtils {
|
18 | 27 | private static final Logger log = LoggerFactory.getLogger(RedisUtils.class);
|
@@ -227,4 +236,105 @@ public boolean checkResponse(@Nullable final String response) {
|
227 | 236 | public boolean checkResponse(@Nullable final Long response) {
|
228 | 237 | return response != null && response == 1;
|
229 | 238 | }
|
| 239 | + |
| 240 | + // Azure Cache for Redis helper code |
| 241 | + public static Jedis createJedisClient(String cacheHostname, int port, String username, AccessToken accessToken, boolean useSsl) { |
| 242 | + return new Jedis(cacheHostname, port, DefaultJedisClientConfig.builder() |
| 243 | + .password(accessToken.getToken()) |
| 244 | + .user(username) |
| 245 | + .ssl(useSsl) |
| 246 | + .build()); |
| 247 | + } |
| 248 | + |
| 249 | + public static String extractUsernameFromToken(String token) { |
| 250 | + String[] parts = token.split("\\."); |
| 251 | + String base64 = parts[1]; |
| 252 | + |
| 253 | + switch (base64.length() % 4) { |
| 254 | + case 2: |
| 255 | + base64 += "=="; |
| 256 | + break; |
| 257 | + case 3: |
| 258 | + base64 += "="; |
| 259 | + break; |
| 260 | + } |
| 261 | + |
| 262 | + byte[] jsonBytes = Base64.getDecoder().decode(base64); |
| 263 | + String json = new String(jsonBytes, StandardCharsets.UTF_8); |
| 264 | + JsonObject jwt = JsonParser.parseString(json).getAsJsonObject(); |
| 265 | + |
| 266 | + return jwt.get("oid").getAsString(); |
| 267 | + } |
| 268 | + |
| 269 | + /** |
| 270 | + * The token cache to store and proactively refresh the access token. |
| 271 | + */ |
| 272 | + public static class TokenRefreshCache { |
| 273 | + private final TokenCredential tokenCredential; |
| 274 | + private final TokenRequestContext tokenRequestContext; |
| 275 | + private final Timer timer; |
| 276 | + private volatile AccessToken accessToken; |
| 277 | + private final Duration maxRefreshOffset = Duration.ofMinutes(5); |
| 278 | + private final Duration baseRefreshOffset = Duration.ofMinutes(2); |
| 279 | + private Jedis jedisInstanceToAuthenticate; |
| 280 | + private String username; |
| 281 | + |
| 282 | + /** |
| 283 | + * Creates an instance of TokenRefreshCache |
| 284 | + * @param tokenCredential the token credential to be used for authentication. |
| 285 | + * @param tokenRequestContext the token request context to be used for authentication. |
| 286 | + */ |
| 287 | + public TokenRefreshCache(TokenCredential tokenCredential, TokenRequestContext tokenRequestContext) { |
| 288 | + this.tokenCredential = tokenCredential; |
| 289 | + this.tokenRequestContext = tokenRequestContext; |
| 290 | + this.timer = new Timer(); |
| 291 | + } |
| 292 | + |
| 293 | + /** |
| 294 | + * Gets the cached access token. |
| 295 | + * @return the AccessToken |
| 296 | + */ |
| 297 | + public AccessToken getAccessToken() { |
| 298 | + if (accessToken != null) { |
| 299 | + return accessToken; |
| 300 | + } else { |
| 301 | + TokenRefreshTask tokenRefreshTask = new TokenRefreshTask(); |
| 302 | + accessToken = tokenCredential.getToken(tokenRequestContext).block(); |
| 303 | + timer.schedule(tokenRefreshTask, getTokenRefreshDelay()); |
| 304 | + return accessToken; |
| 305 | + } |
| 306 | + } |
| 307 | + |
| 308 | + private class TokenRefreshTask extends TimerTask { |
| 309 | + // Add your task here |
| 310 | + public void run() { |
| 311 | + accessToken = tokenCredential.getToken(tokenRequestContext).block(); |
| 312 | + username = extractUsernameFromToken(accessToken.getToken()); |
| 313 | + System.out.println("Refreshed Token with Expiry: " + accessToken.getExpiresAt().toEpochSecond()); |
| 314 | + |
| 315 | + if (jedisInstanceToAuthenticate != null && !CoreUtils.isNullOrEmpty(username)) { |
| 316 | + jedisInstanceToAuthenticate.auth(username, accessToken.getToken()); |
| 317 | + System.out.println("Refreshed Jedis Connection with fresh access token, token expires at : " |
| 318 | + + accessToken.getExpiresAt().toEpochSecond()); |
| 319 | + } |
| 320 | + timer.schedule(new TokenRefreshTask(), getTokenRefreshDelay()); |
| 321 | + } |
| 322 | + } |
| 323 | + |
| 324 | + private long getTokenRefreshDelay() { |
| 325 | + return ((accessToken.getExpiresAt() |
| 326 | + .minusSeconds(ThreadLocalRandom.current().nextLong(baseRefreshOffset.getSeconds(), maxRefreshOffset.getSeconds())) |
| 327 | + .toEpochSecond() - OffsetDateTime.now().toEpochSecond()) * 1000); |
| 328 | + } |
| 329 | + |
| 330 | + /** |
| 331 | + * Sets the Jedis to proactively authenticate before token expiry. |
| 332 | + * @param jedisInstanceToAuthenticate the instance to authenticate |
| 333 | + * @return the updated instance |
| 334 | + */ |
| 335 | + public TokenRefreshCache setJedisInstanceToAuthenticate(Jedis jedisInstanceToAuthenticate) { |
| 336 | + this.jedisInstanceToAuthenticate = jedisInstanceToAuthenticate; |
| 337 | + return this; |
| 338 | + } |
| 339 | + } |
230 | 340 | }
|
0 commit comments