|
| 1 | +/** |
| 2 | + * Copyright (c) 2014-2025 Snowplow Analytics Ltd. All rights reserved. |
| 3 | + * |
| 4 | + * This program is licensed to you under the Apache License Version 2.0, |
| 5 | + * and you may not use this file except in compliance with the Apache License Version 2.0. |
| 6 | + * You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0. |
| 7 | + * |
| 8 | + * Unless required by applicable law or agreed to in writing, |
| 9 | + * software distributed under the Apache License Version 2.0 is distributed on an |
| 10 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | + * See the Apache License Version 2.0 for the specific language governing permissions and limitations there under. |
| 12 | + */ |
| 13 | +package com.snowplowanalytics.stream.loader |
| 14 | +package clients |
| 15 | + |
| 16 | +import java.net.URLEncoder |
| 17 | +import java.nio.charset.StandardCharsets |
| 18 | +import java.security.{InvalidKeyException, MessageDigest, NoSuchAlgorithmException} |
| 19 | +import java.time.LocalDateTime |
| 20 | +import java.time.format.DateTimeFormatter |
| 21 | +import javax.crypto.Mac |
| 22 | +import javax.crypto.spec.SecretKeySpec |
| 23 | + |
| 24 | +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, AWSSessionCredentials} |
| 25 | + |
| 26 | +import scala.collection.immutable.{ListMap, TreeMap} |
| 27 | + |
| 28 | +case object AwsSigner { |
| 29 | + |
| 30 | + def apply( |
| 31 | + credentialsProvider: AWSCredentialsProvider, |
| 32 | + region: String, |
| 33 | + service: String, |
| 34 | + clock: () => LocalDateTime |
| 35 | + ): AwsSigner = new AwsSigner(credentialsProvider, region, service, clock) |
| 36 | + |
| 37 | + def apply( |
| 38 | + awsAccessKeyId: String, |
| 39 | + awsSecretKey: String, |
| 40 | + region: String, |
| 41 | + service: String, |
| 42 | + clock: () => LocalDateTime |
| 43 | + ): AwsSigner = { |
| 44 | + |
| 45 | + val credentialsProvider = new AWSCredentialsProvider { |
| 46 | + override def refresh(): Unit = () |
| 47 | + |
| 48 | + override def getCredentials: AWSCredentials = new AWSCredentials { |
| 49 | + override def getAWSAccessKeyId: String = awsAccessKeyId |
| 50 | + |
| 51 | + override def getAWSSecretKey: String = awsSecretKey |
| 52 | + } |
| 53 | + } |
| 54 | + new AwsSigner(credentialsProvider, region, service, clock) |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +class AwsSigner( |
| 59 | + credentialsProvider: AWSCredentialsProvider, |
| 60 | + region: String, |
| 61 | + service: String, |
| 62 | + clock: () => LocalDateTime |
| 63 | +) { |
| 64 | + |
| 65 | + val HMAC_SHA256 = "HmacSHA256" |
| 66 | + val SLASH = "/" |
| 67 | + val X_AMZ_DATE = "x-amz-date" |
| 68 | + val RETURN = "\n" |
| 69 | + val AWS4_HMAC_SHA256 = "AWS4-HMAC-SHA256" |
| 70 | + val AWS4_REQUEST = "/aws4_request" |
| 71 | + val AWS4_HMAC_SHA256_CREDENTIAL = "AWS4-HMAC-SHA256 Credential=" |
| 72 | + val SIGNED_HEADERS = ", SignedHeaders=" |
| 73 | + val SIGNATURE = ", Signature=" |
| 74 | + val SHA_256 = "SHA-256" |
| 75 | + val AWS4 = "AWS4" |
| 76 | + val AWS_4_REQUEST = "aws4_request" |
| 77 | + val CONNECTION = "connection" |
| 78 | + val CLOSE = ":close" |
| 79 | + val EMPTY = "" |
| 80 | + val ZERO = "0" |
| 81 | + val CONTENT_LENGTH = "Content-Length" |
| 82 | + val AUTHORIZATION = "Authorization" |
| 83 | + val SESSION_TOKEN = "x-amz-security-token" |
| 84 | + val DATE = "date" |
| 85 | + val DATE_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'") |
| 86 | + val BASIC_DATE_FORMATTER = DateTimeFormatter.BASIC_ISO_DATE |
| 87 | + |
| 88 | + def getSignedHeaders( |
| 89 | + uri: String, |
| 90 | + method: String, |
| 91 | + queryParams: Map[String, String], |
| 92 | + headers: Map[String, String], |
| 93 | + payload: Option[Array[Byte]] |
| 94 | + ): Map[String, String] = { |
| 95 | + |
| 96 | + def queryParamsString(queryParams: Map[String, String]) = { |
| 97 | + // sort params by key in ascending order |
| 98 | + val orderedParams = ListMap(queryParams.toSeq.sortWith(_._1 < _._1): _*) |
| 99 | + |
| 100 | + // encode params |
| 101 | + orderedParams |
| 102 | + .map { case (key, value) => |
| 103 | + key + "=" + URLEncoder.encode(value, StandardCharsets.UTF_8.toString) |
| 104 | + } |
| 105 | + .mkString("&") |
| 106 | + } |
| 107 | + |
| 108 | + def sign(stringToSign: String, now: LocalDateTime, credentials: AWSCredentials): String = { |
| 109 | + def hmacSHA256(data: String, key: Array[Byte]): Array[Byte] = { |
| 110 | + try { |
| 111 | + val mac: Mac = Mac.getInstance(HMAC_SHA256) |
| 112 | + mac.init(new SecretKeySpec(key, HMAC_SHA256)) |
| 113 | + mac.doFinal(data.getBytes(StandardCharsets.UTF_8)) |
| 114 | + } catch { |
| 115 | + case e: NoSuchAlgorithmException => throw e |
| 116 | + case i: InvalidKeyException => throw i |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + def getSignatureKey(now: LocalDateTime, credentials: AWSCredentials): Array[Byte] = { |
| 121 | + val kSecret: Array[Byte] = |
| 122 | + (AWS4 + credentials.getAWSSecretKey).getBytes(StandardCharsets.UTF_8) |
| 123 | + val kDate: Array[Byte] = hmacSHA256(now.format(BASIC_DATE_FORMATTER), kSecret) |
| 124 | + val kRegion: Array[Byte] = hmacSHA256(region, kDate) |
| 125 | + val kService: Array[Byte] = hmacSHA256(service, kRegion) |
| 126 | + hmacSHA256(AWS_4_REQUEST, kService) |
| 127 | + } |
| 128 | + |
| 129 | + toBase16(hmacSHA256(stringToSign, getSignatureKey(now, credentials))) |
| 130 | + } |
| 131 | + |
| 132 | + def headerAsString(header: (String, Object), method: String): String = |
| 133 | + if (header._1.equalsIgnoreCase(CONNECTION)) { |
| 134 | + CONNECTION + CLOSE |
| 135 | + } else if ( |
| 136 | + header._1.equalsIgnoreCase(CONTENT_LENGTH) && header._2.equals(ZERO) && !method |
| 137 | + .equalsIgnoreCase("POST") |
| 138 | + ) { |
| 139 | + header._1.toLowerCase + ':' |
| 140 | + } else { |
| 141 | + header._1.toLowerCase + ':' + header._2 |
| 142 | + } |
| 143 | + |
| 144 | + def getCredentialScope(now: LocalDateTime): String = |
| 145 | + now.format(BASIC_DATE_FORMATTER) + SLASH + region + SLASH + service + AWS4_REQUEST |
| 146 | + |
| 147 | + def hash(payload: Array[Byte]): Array[Byte] = |
| 148 | + try { |
| 149 | + val md: MessageDigest = MessageDigest.getInstance(SHA_256) |
| 150 | + md.update(payload) |
| 151 | + md.digest |
| 152 | + } catch { |
| 153 | + case n: NoSuchAlgorithmException => throw n |
| 154 | + } |
| 155 | + |
| 156 | + def toBase16(data: Array[Byte]): String = data.map("%02x" format _).mkString |
| 157 | + |
| 158 | + def createStringToSign(canonicalRequest: String, now: LocalDateTime): String = |
| 159 | + AWS4_HMAC_SHA256 + RETURN + |
| 160 | + now.format(DATE_FORMATTER) + RETURN + |
| 161 | + getCredentialScope(now) + RETURN + |
| 162 | + toBase16(hash(canonicalRequest.getBytes(StandardCharsets.UTF_8))) |
| 163 | + |
| 164 | + // evaluate current time from the provided clock |
| 165 | + val now: LocalDateTime = clock.apply() |
| 166 | + |
| 167 | + // signing credentials |
| 168 | + val credentials: AWSCredentials = credentialsProvider.getCredentials |
| 169 | + |
| 170 | + var result = TreeMap[String, String]()(Ordering.by(_.toLowerCase)) |
| 171 | + for ((key, value) <- headers) result += key -> value |
| 172 | + |
| 173 | + if (!result.contains(DATE)) { |
| 174 | + result += (X_AMZ_DATE -> now.format(DATE_FORMATTER)) |
| 175 | + } |
| 176 | + |
| 177 | + credentials match { |
| 178 | + case asc: AWSSessionCredentials => result += (SESSION_TOKEN -> asc.getSessionToken) |
| 179 | + case _ => // do nothing |
| 180 | + } |
| 181 | + |
| 182 | + val headersString: String = result.map(pair => headerAsString(pair, method) + RETURN).mkString |
| 183 | + val signedHeaders: List[String] = result.map(pair => pair._1.toLowerCase).toList |
| 184 | + |
| 185 | + val signedHeaderKeys = signedHeaders.mkString(";") |
| 186 | + val canonicalRequest = |
| 187 | + method + RETURN + |
| 188 | + uri + RETURN + |
| 189 | + queryParamsString(queryParams) + RETURN + |
| 190 | + headersString + RETURN + |
| 191 | + signedHeaderKeys + RETURN + |
| 192 | + toBase16(hash(payload.getOrElse(EMPTY.getBytes(StandardCharsets.UTF_8)))) |
| 193 | + |
| 194 | + val stringToSign = createStringToSign(canonicalRequest, now) |
| 195 | + val signature = sign(stringToSign, now, credentials) |
| 196 | + val authorizationHeader = AWS4_HMAC_SHA256_CREDENTIAL + |
| 197 | + credentials.getAWSAccessKeyId + SLASH + getCredentialScope(now) + |
| 198 | + SIGNED_HEADERS + signedHeaderKeys + |
| 199 | + SIGNATURE + signature |
| 200 | + |
| 201 | + result += (AUTHORIZATION -> authorizationHeader) |
| 202 | + |
| 203 | + result |
| 204 | + } |
| 205 | +} |
0 commit comments