diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/SqlClientSSLTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/SqlClientSSLTest.java new file mode 100644 index 0000000000000..ccb4832f0765d --- /dev/null +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/SqlClientSSLTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.client; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.table.gateway.rest.util.SqlGatewayRestEndpointExtension; +import org.apache.flink.table.gateway.service.utils.SqlGatewayServiceExtension; + +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test that {@link SqlClient} works normally when SSL is enabled. */ +class SqlClientSSLTest extends SqlClientTestBase { + @RegisterExtension + @Order(1) + public static final SqlGatewayServiceExtension SQL_GATEWAY_SERVICE_EXTENSION = + new SqlGatewayServiceExtension(Configuration::new); + + @RegisterExtension + @Order(2) + private static final SqlGatewayRestEndpointExtension SQL_GATEWAY_REST_ENDPOINT_EXTENSION = + new SqlGatewayRestEndpointExtension( + SQL_GATEWAY_SERVICE_EXTENSION::getService, SqlClientSSLTest::withSSL); + + private static final String truststorePath = getTestResource("ssl/local127.truststore"); + + private static final String keystorePath = getTestResource("ssl/local127.keystore"); + + @Test + void testEmbeddedMode() throws Exception { + String[] args = new String[] {"embedded"}; + String actual = runSqlClient(args, String.join("\n", "SET;", "QUIT;"), false); + assertThat(actual).contains(SecurityOptions.SSL_REST_ENABLED.key(), "true"); + } + + @Test + void testGatewayMode() throws Exception { + String[] args = + new String[] { + "gateway", + "-e", + InetSocketAddress.createUnresolved( + SQL_GATEWAY_REST_ENDPOINT_EXTENSION.getTargetAddress(), + SQL_GATEWAY_REST_ENDPOINT_EXTENSION.getTargetPort()) + .toString() + }; + String actual = runSqlClient(args, String.join("\n", "SET;", "QUIT;"), false); + assertThat(actual).contains(SecurityOptions.SSL_REST_ENABLED.key(), "true"); + } + + private static void withSSL(Configuration configuration) { + configuration.set(SecurityOptions.SSL_REST_ENABLED, true); + configuration.set(SecurityOptions.SSL_REST_TRUSTSTORE, truststorePath); + configuration.set(SecurityOptions.SSL_REST_TRUSTSTORE_PASSWORD, "password"); + configuration.set(SecurityOptions.SSL_REST_KEYSTORE, keystorePath); + configuration.set(SecurityOptions.SSL_REST_KEYSTORE_PASSWORD, "password"); + configuration.set(SecurityOptions.SSL_REST_KEY_PASSWORD, "password"); + } + + @Override + protected void writeConfigOptionsToConfYaml(Path confYamlPath) throws IOException { + Configuration configuration = new Configuration(); + withSSL(configuration); + Files.write( + confYamlPath, + configuration.toMap().entrySet().stream() + .map(entry -> entry.getKey() + ": " + entry.getValue()) + .collect(Collectors.toList())); + } + + private static String getTestResource(final String fileName) { + final ClassLoader classLoader = ClassLoader.getSystemClassLoader(); + final URL resource = classLoader.getResource(fileName); + if (resource == null) { + throw new IllegalArgumentException( + String.format("Test resource %s does not exist", fileName)); + } + return resource.getFile(); + } +} diff --git a/flink-table/flink-sql-client/src/test/resources/ssl/local127.keystore b/flink-table/flink-sql-client/src/test/resources/ssl/local127.keystore new file mode 100644 index 0000000000000..4992ac4513758 Binary files /dev/null and b/flink-table/flink-sql-client/src/test/resources/ssl/local127.keystore differ diff --git a/flink-table/flink-sql-client/src/test/resources/ssl/local127.truststore b/flink-table/flink-sql-client/src/test/resources/ssl/local127.truststore new file mode 100644 index 0000000000000..df4acf86b0d2c Binary files /dev/null and b/flink-table/flink-sql-client/src/test/resources/ssl/local127.truststore differ diff --git a/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/endpoint/SqlGatewayEndpointFactory.java b/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/endpoint/SqlGatewayEndpointFactory.java index e48fe1534517c..f83dc4a90e512 100644 --- a/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/endpoint/SqlGatewayEndpointFactory.java +++ b/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/endpoint/SqlGatewayEndpointFactory.java @@ -53,7 +53,7 @@ interface Context { /** * Get a map contains all flink configurations. * - * @return The copy of flink configurations in the form of map, modify this map will not + * @return The copy of flink configurations in the form of map, modifying this map will not * influence the original configuration object. */ Map getFlinkConfigurationOptions(); diff --git a/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/rest/SqlGatewayRestEndpointFactory.java b/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/rest/SqlGatewayRestEndpointFactory.java index e0e0956f81f6a..de2622b80c0f2 100644 --- a/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/rest/SqlGatewayRestEndpointFactory.java +++ b/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/rest/SqlGatewayRestEndpointFactory.java @@ -26,7 +26,6 @@ import org.apache.flink.table.gateway.api.endpoint.SqlGatewayEndpointFactoryUtils; import org.apache.flink.table.gateway.api.utils.SqlGatewayException; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -48,7 +47,9 @@ public SqlGatewayEndpoint createSqlGatewayEndpoint(Context context) { SqlGatewayEndpointFactoryUtils.createEndpointFactoryHelper(this, context); // Check that ADDRESS must be set endpointFactoryHelper.validate(); - Configuration config = rebuildRestEndpointOptions(context.getEndpointOptions()); + Configuration config = + rebuildRestEndpointOptions( + context.getEndpointOptions(), context.getFlinkConfigurationOptions()); try { return new SqlGatewayRestEndpoint(config, context.getSqlGatewayService()); } catch (Exception e) { @@ -56,26 +57,25 @@ public SqlGatewayEndpoint createSqlGatewayEndpoint(Context context) { } } - public static Configuration rebuildRestEndpointOptions(Map configMap) { - Map effectiveConfigMap = new HashMap<>(configMap); + public static Configuration rebuildRestEndpointOptions( + Map endpointConfigMap, Map flinkConfigMap) { + flinkConfigMap.put(RestOptions.ADDRESS.key(), endpointConfigMap.get(ADDRESS.key())); - effectiveConfigMap.put(RestOptions.ADDRESS.key(), configMap.get(ADDRESS.key())); - - if (configMap.containsKey(BIND_ADDRESS.key())) { - effectiveConfigMap.put( - RestOptions.BIND_ADDRESS.key(), configMap.get(BIND_ADDRESS.key())); + if (endpointConfigMap.containsKey(BIND_ADDRESS.key())) { + flinkConfigMap.put( + RestOptions.BIND_ADDRESS.key(), endpointConfigMap.get(BIND_ADDRESS.key())); } // we need to override RestOptions.PORT anyway, to use a different default value - effectiveConfigMap.put( + flinkConfigMap.put( RestOptions.PORT.key(), - configMap.getOrDefault(PORT.key(), PORT.defaultValue().toString())); + endpointConfigMap.getOrDefault(PORT.key(), PORT.defaultValue().toString())); - if (configMap.containsKey(BIND_PORT.key())) { - effectiveConfigMap.put(RestOptions.BIND_PORT.key(), configMap.get(BIND_PORT.key())); + if (endpointConfigMap.containsKey(BIND_PORT.key())) { + flinkConfigMap.put(RestOptions.BIND_PORT.key(), endpointConfigMap.get(BIND_PORT.key())); } - return Configuration.fromMap(effectiveConfigMap); + return Configuration.fromMap(flinkConfigMap); } @Override diff --git a/flink-table/flink-sql-gateway/src/test/java/org/apache/flink/table/gateway/rest/util/SqlGatewayRestEndpointTestUtils.java b/flink-table/flink-sql-gateway/src/test/java/org/apache/flink/table/gateway/rest/util/SqlGatewayRestEndpointTestUtils.java index 65daeb4a3a819..2d246aeee5796 100644 --- a/flink-table/flink-sql-gateway/src/test/java/org/apache/flink/table/gateway/rest/util/SqlGatewayRestEndpointTestUtils.java +++ b/flink-table/flink-sql-gateway/src/test/java/org/apache/flink/table/gateway/rest/util/SqlGatewayRestEndpointTestUtils.java @@ -40,7 +40,8 @@ public static Configuration getBaseConfig(Configuration flinkConf) { new SqlGatewayEndpointFactoryUtils.DefaultEndpointFactoryContext( null, flinkConf, getEndpointConfig(flinkConf, IDENTIFIER)); - return rebuildRestEndpointOptions(context.getEndpointOptions()); + return rebuildRestEndpointOptions( + context.getEndpointOptions(), context.getFlinkConfigurationOptions()); } /** Create the configuration generated from config.yaml. */