diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java index 00698846a85a3..d01333860e878 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java @@ -18,9 +18,9 @@ public class ArrowFlightConfig { private String server; - private Boolean verifyServer; + private boolean verifyServer = true; private String flightServerSSLCertificate; - private Boolean arrowFlightServerSslEnabled; + private boolean arrowFlightServerSslEnabled; private Integer arrowFlightPort; public String getFlightServerName() @@ -35,13 +35,13 @@ public ArrowFlightConfig setFlightServerName(String server) return this; } - public Boolean getVerifyServer() + public boolean getVerifyServer() { return verifyServer; } @Config("arrow-flight.server.verify") - public ArrowFlightConfig setVerifyServer(Boolean verifyServer) + public ArrowFlightConfig setVerifyServer(boolean verifyServer) { this.verifyServer = verifyServer; return this; @@ -71,13 +71,13 @@ public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCer return this; } - public Boolean getArrowFlightServerSslEnabled() + public boolean getArrowFlightServerSslEnabled() { return arrowFlightServerSslEnabled; } @Config("arrow-flight.server-ssl-enabled") - public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled) + public ArrowFlightConfig setArrowFlightServerSslEnabled(boolean arrowFlightServerSslEnabled) { this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled; return this; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java index 573265ef637d4..e8fde1ec39521 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java @@ -53,11 +53,11 @@ public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig protected FlightClient createFlightClient() { Location location; - if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) { - location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); + if (config.getArrowFlightServerSslEnabled()) { + location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); } else { - location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); + location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); } return createFlightClient(location); } @@ -67,10 +67,8 @@ protected FlightClient createFlightClient(Location location) try { Optional trustedCertificate = Optional.empty(); FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); - if (config.getVerifyServer() != null && !config.getVerifyServer()) { - flightClientBuilder.verifyServer(false); - } - else if (config.getFlightServerSSLCertificate() != null) { + flightClientBuilder.verifyServer(config.getVerifyServer()); + if (config.getFlightServerSSLCertificate() != null) { trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate()))); flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java index f105359f0ae73..384b95b06ad56 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.Map; +import java.util.Optional; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -53,11 +54,15 @@ private static DistributedQueryRunner createQueryRunner( throws Exception { Session session = testSessionBuilder() - .setCatalog("arrow") + .setCatalog("arrowflight") .setSchema("tpch") .build(); - DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).setExtraProperties(extraProperties).build(); + DistributedQueryRunner.Builder queryRunnerBuilder = DistributedQueryRunner.builder(session); + Optional workerCount = getProperty("WORKER_COUNT").map(Integer::parseInt); + workerCount.ifPresent(queryRunnerBuilder::setNodeCount); + + DistributedQueryRunner queryRunner = queryRunnerBuilder.setExtraProperties(extraProperties).build(); try { queryRunner.installPlugin(new TestingArrowFlightPlugin()); @@ -66,10 +71,9 @@ private static DistributedQueryRunner createQueryRunner( .putAll(catalogProperties) .put("arrow-flight.server", "localhost") .put("arrow-flight.server-ssl-enabled", "true") - .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt") - .put("arrow-flight.server.verify", "true"); + .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt"); - queryRunner.createCatalog("arrow", "arrow", properties.build()); + queryRunner.createCatalog("arrowflight", "arrow-flight", properties.build()); return queryRunner; } @@ -78,6 +82,19 @@ private static DistributedQueryRunner createQueryRunner( } } + private static Optional getProperty(String name) + { + String systemPropertyValue = System.getProperty(name); + if (systemPropertyValue != null) { + return Optional.of(systemPropertyValue); + } + String environmentVariableValue = System.getenv(name); + if (environmentVariableValue != null) { + return Optional.of(environmentVariableValue); + } + return Optional.empty(); + } + public static void main(String[] args) throws Exception { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java index 8263defd59585..d9e057598b207 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java @@ -410,7 +410,7 @@ private static MapType createMapType(Type keyType, Type valueType) private static FlightClient createFlightClient(BufferAllocator allocator) throws IOException { InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/server.crt"))); - return FlightClient.builder(allocator, getServerLocation()).verifyServer(true).useTls().trustedCertificates(trustedCertificate).build(); + return FlightClient.builder(allocator, getServerLocation()).useTls().trustedCertificates(trustedCertificate).build(); } private void addTableToServer(FlightClient client, VectorSchemaRoot root, String tableName) diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java index 196098c3151c7..830b5b04b3b5e 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java @@ -21,6 +21,6 @@ public class TestingArrowFlightPlugin { public TestingArrowFlightPlugin() { - super("arrow", new TestingArrowModule(), new JsonModule()); + super("arrow-flight", new TestingArrowModule(), new JsonModule()); } }