From 4775ae505c0381d38d61b151df1e519e86b1b564 Mon Sep 17 00:00:00 2001 From: Deepak Majeti Date: Wed, 29 May 2024 13:51:46 +0530 Subject: [PATCH 1/2] [native] Pass extraCredentials to connector session properties --- .../presto_cpp/main/QueryContextManager.cpp | 21 +++++++++++------- .../presto_cpp/main/QueryContextManager.h | 2 +- .../presto_cpp/main/TaskResource.cpp | 4 ++-- .../main/tests/QueryContextManagerTest.cpp | 22 ++++++++++++++----- .../main/tests/ServerOperationTest.cpp | 2 +- .../presto_cpp/main/tests/TaskManagerTest.cpp | 2 +- 6 files changed, 34 insertions(+), 19 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp index 0dbed965fc742..043f354e783fb 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp @@ -60,14 +60,17 @@ void updateFromSystemConfigs( } std::unordered_map> -toConnectorConfigs(const protocol::SessionRepresentation& session) { +toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest) { std::unordered_map> connectorConfigs; - for (const auto& entry : session.catalogProperties) { - connectorConfigs.insert( - {entry.first, - std::unordered_map( - entry.second.begin(), entry.second.end())}); + for (const auto& entry : taskUpdateRequest.session.catalogProperties) { + auto sessionProperties = std::unordered_map( + entry.second.begin(), entry.second.end()); + sessionProperties.insert( + taskUpdateRequest.extraCredentials.begin(), + taskUpdateRequest.extraCredentials.end()); + sessionProperties.insert({"user", taskUpdateRequest.session.user}); + connectorConfigs.insert({entry.first, sessionProperties}); } return connectorConfigs; @@ -106,9 +109,11 @@ QueryContextManager::QueryContextManager( std::shared_ptr QueryContextManager::findOrCreateQueryCtx( const protocol::TaskId& taskId, - const protocol::SessionRepresentation& session) { + const protocol::TaskUpdateRequest& taskUpdateRequest) { return findOrCreateQueryCtx( - taskId, toVeloxConfigs(session), toConnectorConfigs(session)); + taskId, + toVeloxConfigs(taskUpdateRequest.session), + toConnectorConfigs(taskUpdateRequest)); } std::shared_ptr QueryContextManager::findOrCreateQueryCtx( diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.h b/presto-native-execution/presto_cpp/main/QueryContextManager.h index 16d97f551ee3d..f8b1a1836ce55 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.h +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.h @@ -107,7 +107,7 @@ class QueryContextManager { std::shared_ptr findOrCreateQueryCtx( const protocol::TaskId& taskId, - const protocol::SessionRepresentation& session); + const protocol::TaskUpdateRequest& taskUpdateRequest); /// Calls the given functor for every present query context. void visitAllContexts(std::functionfindOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); VeloxBatchQueryPlanConverter converter( shuffleName, @@ -340,7 +340,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTask( queryCtx = taskManager_.getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); VeloxInteractiveQueryPlanConverter converter(queryCtx.get(), pool_); planFragment = converter.toVeloxQueryPlan( diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp index 0f258c5896393..637e21e54a3aa 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp @@ -61,8 +61,10 @@ TEST_F(QueryContextManagerTest, nativeSessionProperties) { {"native_expression_max_array_size_in_reduce", "99999"}, {"native_expression_max_compiled_regexes", "54321"}, }}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ(queryCtx->queryConfig().maxSpillLevel(), 2); EXPECT_EQ(queryCtx->queryConfig().spillCompressionKind(), "NONE"); EXPECT_FALSE(queryCtx->queryConfig().joinSpillEnabled()); @@ -83,8 +85,10 @@ TEST_F(QueryContextManagerTest, defaultSessionProperties) { protocol::TaskId taskId = "scan.0.0.1.0"; protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); const auto& queryConfig = queryCtx->queryConfig(); EXPECT_EQ(queryConfig.maxSpillLevel(), defaultQC->maxSpillLevel()); EXPECT_EQ( @@ -102,9 +106,11 @@ TEST_F(QueryContextManagerTest, overrdingSessionProperties) { const auto& systemConfig = SystemConfig::instance(); { protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ( queryCtx->queryConfig().queryMaxMemoryPerNode(), systemConfig->queryMaxMemoryPerNode()); @@ -117,9 +123,11 @@ TEST_F(QueryContextManagerTest, overrdingSessionProperties) { .systemProperties = { {"query_max_memory_per_node", "1GB"}, {"spill_file_create_config", "encoding:replica_2"}}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ( queryCtx->queryConfig().queryMaxMemoryPerNode(), 1UL * 1024 * 1024 * 1024); @@ -131,6 +139,8 @@ TEST_F(QueryContextManagerTest, overrdingSessionProperties) { TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { const protocol::TaskId fakeTaskId = "scan.0.0.1.0"; const protocol::SessionRepresentation fakeSession{.systemProperties = {}}; + protocol::TaskUpdateRequest fakeUpdateRequest; + fakeUpdateRequest.session = fakeSession; auto* queryCtxManager = taskManager_->getQueryContextManager(); struct { bool hasPendingReference; @@ -154,7 +164,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { queryCtxManager->testingClearCache(); auto queryCtx = - queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeSession); + queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); const auto poolName = queryCtx->pool()->name(); ASSERT_THAT(poolName, testing::HasSubstr("scan_")); if (!testData.hasPendingReference) { @@ -164,7 +174,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { queryCtxManager->testingClearCache(); } auto newQueryCtx = - queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeSession); + queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); const auto newPoolName = newQueryCtx->pool()->name(); ASSERT_THAT(newPoolName, testing::HasSubstr("scan_")); if (testData.expectedNewPoolName) { diff --git a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp index a338882684ebf..05fefe3aa559c 100644 --- a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp @@ -177,7 +177,7 @@ TEST_F(ServerOperationTest, taskEndpoint) { planFragment, true, taskManager->getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session), + taskId, updateRequest), 0); }; std::vector taskIds = {"task_0.0.0.0.0", "task_1.0.0.0.0"}; diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 641f5857b97c1..9d05cf2212194 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -672,7 +672,7 @@ class TaskManagerTest : public exec::test::OperatorTestBase, bool summarize = true) { auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); return taskManager_->createOrUpdateTask( taskId, updateRequest, planFragment, summarize, std::move(queryCtx), 0); } From 9aa8360cd0a52a639766093b4a2661d493621064 Mon Sep 17 00:00:00 2001 From: Rijin-N Date: Thu, 14 Nov 2024 10:17:47 +0530 Subject: [PATCH 2/2] [native] Add Arrow Flight Connector The native Arrow Flight connector can be used to connect to any Arrow Flight enabled Data Source. The metadata layer is handled by the Presto coordinator and does not need to be re-implemented in C++. Any Java connector that inherits from `presto-base-arrow-flight` can use this connector as it's counterpart for the Prestissimo layer. Different Arrow-Flight enabled data sources can differ in authentication styles. A plugin-style interface is provided to handle such cases with custom authentication code by extending `arrow_flight::auth::Authenticator`. RFC: https://github.com/prestodb/rfcs/blob/main/RFC-0004-arrow-flight-connector.md#prestissimo-implementation Co-authored-by: Ashwin Kumar Co-authored-by: Rijin-N Co-authored-by: Nischay Yadav --- presto-native-execution/CMakeLists.txt | 2 + presto-native-execution/README.md | 9 + .../presto_cpp/main/CMakeLists.txt | 3 +- .../presto_cpp/main/PrestoServer.cpp | 54 +-- .../presto_cpp/main/PrestoServer.h | 2 - .../presto_cpp/main/connectors/CMakeLists.txt | 20 ++ .../PrestoToVeloxConnector.cpp | 6 +- .../PrestoToVeloxConnector.h | 4 +- .../main/connectors/Registration.cpp | 95 +++++ .../presto_cpp/main/connectors/Registration.h | 20 ++ .../main/{ => connectors}/SystemConnector.cpp | 2 +- .../main/{ => connectors}/SystemConnector.h | 5 +- .../main/{ => connectors}/SystemSplit.h | 0 .../arrow_flight/ArrowFlightConfig.cpp | 45 +++ .../arrow_flight/ArrowFlightConfig.h | 57 +++ .../arrow_flight/ArrowFlightConnector.cpp | 188 ++++++++++ .../arrow_flight/ArrowFlightConnector.h | 194 +++++++++++ .../ArrowPrestoToVeloxConnector.cpp | 63 ++++ .../ArrowPrestoToVeloxConnector.h | 47 +++ .../connectors/arrow_flight/CMakeLists.txt | 44 +++ .../main/connectors/arrow_flight/Macros.h | 50 +++ .../arrow_flight/auth/Authenticator.cpp | 48 +++ .../arrow_flight/auth/Authenticator.h | 85 +++++ .../arrow_flight/auth/CMakeLists.txt | 15 + .../tests/ArrowFlightConfigTest.cpp | 48 +++ .../tests/ArrowFlightConnectorAuthTest.cpp | 236 +++++++++++++ .../ArrowFlightConnectorDataTypeTest.cpp | 328 ++++++++++++++++++ .../tests/ArrowFlightConnectorTest.cpp | 184 ++++++++++ .../tests/ArrowFlightConnectorTlsTest.cpp | 128 +++++++ .../arrow_flight/tests/CMakeLists.txt | 45 +++ .../tests/TestingArrowFlightServerTest.cpp | 83 +++++ .../arrow_flight/tests/data/README.md | 7 + .../tests/data/generate_tls_certs.sh | 40 +++ .../arrow_flight/tests/data/tls_certs/ca.crt | 22 ++ .../tests/data/tls_certs/server.crt | 22 ++ .../tests/data/tls_certs/server.key | 28 ++ .../utils/ArrowFlightConnectorTestBase.cpp | 83 +++++ .../utils/ArrowFlightConnectorTestBase.h | 89 +++++ .../tests/utils/ArrowFlightPlanBuilder.cpp | 45 +++ .../tests/utils/ArrowFlightPlanBuilder.h | 35 ++ .../arrow_flight/tests/utils/CMakeLists.txt | 19 + .../tests/utils/TestingArrowFlightServer.cpp | 34 ++ .../tests/utils/TestingArrowFlightServer.h | 48 +++ .../arrow_flight/tests/utils/Utils.cpp | 90 +++++ .../arrow_flight/tests/utils/Utils.h | 55 +++ .../main/operators/tests/CMakeLists.txt | 1 + .../presto_cpp/main/tests/TaskManagerTest.cpp | 2 +- .../presto_cpp/main/types/CMakeLists.txt | 10 +- .../main/types/PrestoToVeloxQueryPlan.cpp | 2 +- .../main/types/PrestoToVeloxSplit.cpp | 2 +- .../main/types/tests/CMakeLists.txt | 4 + .../main/types/tests/PlanConverterTest.cpp | 2 +- .../tests/PrestoToVeloxConnectorTest.cpp | 2 +- .../types/tests/PrestoToVeloxSplitTest.cpp | 2 +- .../presto_cpp/presto_protocol/Makefile | 9 + .../ArrowFlightConnectorProtocol.h | 29 ++ .../presto_protocol-json-cpp.mustache | 150 ++++++++ .../presto_protocol-json-hpp.mustache | 76 ++++ .../presto_protocol_arrow_flight.cpp | 215 ++++++++++++ .../presto_protocol_arrow_flight.h | 82 +++++ .../presto_protocol_arrow_flight.yml | 40 +++ .../special/ArrowTransactionHandle.cpp.inc | 30 ++ .../special/ArrowTransactionHandle.hpp.inc | 28 ++ .../core/presto_protocol_core.cpp | 1 + .../core/presto_protocol_core.yml | 4 + .../ConnectorTransactionHandle.cpp.inc | 1 + .../presto_protocol/presto_protocol.cpp | 1 + .../presto_protocol/presto_protocol.h | 1 + .../presto_protocol/presto_protocol.yml | 8 + .../scripts/setup-adapters.sh | 68 ++++ 70 files changed, 3426 insertions(+), 71 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt rename presto-native-execution/presto_cpp/main/{types => connectors}/PrestoToVeloxConnector.cpp (99%) rename presto-native-execution/presto_cpp/main/{types => connectors}/PrestoToVeloxConnector.h (99%) create mode 100644 presto-native-execution/presto_cpp/main/connectors/Registration.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/Registration.h rename presto-native-execution/presto_cpp/main/{ => connectors}/SystemConnector.cpp (99%) rename presto-native-execution/presto_cpp/main/{ => connectors}/SystemConnector.h (98%) rename presto-native-execution/presto_cpp/main/{ => connectors}/SystemSplit.h (100%) create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md create mode 100755 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d5001dde70a74..6ac789b73e3a1 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -63,6 +63,8 @@ option(PRESTO_ENABLE_TESTING "Enable tests" ON) option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) +option(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR "Enable Arrow Flight connector" OFF) + # Set all Velox options below add_compile_definitions(FOLLY_HAVE_INT128_T=1) diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index cccebfcfb8d03..1976be406c2e0 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -115,6 +115,15 @@ follow these steps: * For development, use `make debug` to build a non-optimized debug version. * Use `make unittest` to build and run tests. +#### Arrow Flight Connector +To enable Arrow Flight connector support, add to the extra cmake flags: +`EXTRA_CMAKE_FLAGS = -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=ON` + +The Arrow Flight connector requires the Arrow Flight library. You can install this dependency +by running the following script from the `presto/presto-native-execution` directory: + +`./scripts/setup-adapters.sh arrow_flight` + ### Makefile Targets A reminder of the available Makefile targets can be obtained using `make help` ``` diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index c06e00edf834c..9930a41ad4670 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) add_subdirectory(thrift) +add_subdirectory(connectors) add_library( presto_server_lib @@ -29,7 +30,6 @@ add_library( QueryContextManager.cpp ServerOperation.cpp SignalHandler.cpp - SystemConnector.cpp SessionProperties.cpp TaskManager.cpp TaskResource.cpp @@ -48,6 +48,7 @@ target_link_libraries( presto_common presto_exception presto_function_metadata + presto_connector presto_http presto_operators presto_velox_conversion diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index b00422315e3ac..e4fb40863b5bb 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -21,11 +21,12 @@ #include "presto_cpp/main/PeriodicMemoryChecker.h" #include "presto_cpp/main/PeriodicTaskManager.h" #include "presto_cpp/main/SignalHandler.h" -#include "presto_cpp/main/SystemConnector.h" #include "presto_cpp/main/TaskResource.h" #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Counters.h" #include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/connectors/Registration.h" +#include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -48,13 +49,11 @@ #include "velox/common/memory/MmapAllocator.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/connectors/Connector.h" -#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" -#include "velox/connectors/tpch/TpchConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/dwio/orc/reader/OrcReader.h" @@ -88,7 +87,6 @@ constexpr char const* kHttps = "https"; constexpr char const* kTaskUriFormat = "{}://{}:{}"; // protocol, address and port constexpr char const* kConnectorName = "connector.name"; -constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; protocol::NodeState convertNodeState(presto::NodeState nodeState) { switch (nodeState) { @@ -254,33 +252,9 @@ void PrestoServer::run() { registerMemoryArbitrators(); registerShuffleInterfaceFactories(); registerCustomOperators(); - registerConnectorFactories(); - - // Register Velox connector factory for iceberg. - // The iceberg catalog is handled by the hive connector factory. - velox::connector::registerConnectorFactory( - std::make_shared( - "iceberg")); - - registerPrestoToVeloxConnector( - std::make_unique("hive")); - registerPrestoToVeloxConnector( - std::make_unique("hive-hadoop2")); - registerPrestoToVeloxConnector( - std::make_unique("iceberg")); - registerPrestoToVeloxConnector( - std::make_unique("tpch")); - // Presto server uses system catalog or system schema in other catalogs - // in different places in the code. All these resolve to the SystemConnector. - // Depending on where the operator or column is used, different prefixes can - // be used in the naming. So the protocol class is mapped - // to all the different prefixes for System tables/columns. - registerPrestoToVeloxConnector( - std::make_unique("$system")); - registerPrestoToVeloxConnector( - std::make_unique("system")); - registerPrestoToVeloxConnector( - std::make_unique("$system@system")); + + // Register Presto connector factories and connectors + presto::registerConnectors(); velox::exec::OutputBufferManager::initialize({}); initializeVeloxMemory(); @@ -1165,24 +1139,6 @@ PrestoServer::getAdditionalHttpServerFilters() { return filters; } -void PrestoServer::registerConnectorFactories() { - // These checks for connector factories can be removed after we remove the - // registrations from the Velox library. - if (!velox::connector::hasConnectorFactory( - velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - velox::connector::registerConnectorFactory( - std::make_shared( - kHiveHadoop2ConnectorName)); - } - if (!velox::connector::hasConnectorFactory( - velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } -} - std::vector PrestoServer::registerConnectors( const fs::path& configDirectoryPath) { static const std::string kPropertiesExtension = ".properties"; diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index 9fa1301c1bf1d..ff3032456362d 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -146,8 +146,6 @@ class PrestoServer { virtual void unregisterFileReadersAndWriters(); - virtual void registerConnectorFactories(); - /// Invoked by presto shutdown procedure to unregister connectors. virtual void unregisterConnectors(); diff --git a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt new file mode 100644 index 0000000000000..3729f2a0481c6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_connector Registration.cpp PrestoToVeloxConnector.cpp + SystemConnector.cpp) + +if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + add_subdirectory(arrow_flight) + target_link_libraries(presto_connector presto_flight_connector) +endif() + +target_link_libraries(presto_connector presto_types) \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp similarity index 99% rename from presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp rename to presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp index 38c8e785663ad..76913748569e7 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp @@ -12,12 +12,14 @@ * limitations under the License. */ -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" -#include +#include #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/HiveDataSink.h" diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h similarity index 99% rename from presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h rename to presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h index eb33dfb54ca1d..eed81e4cc00f3 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h @@ -13,8 +13,6 @@ */ #pragma once -#include "PrestoToVeloxExpr.h" -#include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" #include "velox/connectors/Connector.h" @@ -25,6 +23,8 @@ namespace facebook::presto { class PrestoToVeloxConnector; +class TypeParser; +class VeloxExprConverter; void registerPrestoToVeloxConnector( std::unique_ptr connector); diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp new file mode 100644 index 0000000000000..d6f6555fb8a22 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/Registration.h" +#include "presto_cpp/main/connectors/SystemConnector.h" + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#endif + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/tpch/TpchConnector.h" + +namespace facebook::presto { +namespace { + +constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; +constexpr char const* kIcebergConnectorName = "iceberg"; + +void registerConnectorFactories() { + // These checks for connector factories can be removed after we remove the + // registrations from the Velox library. + if (!velox::connector::hasConnectorFactory( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + velox::connector::registerConnectorFactory( + std::make_shared( + kHiveHadoop2ConnectorName)); + } + if (!velox::connector::hasConnectorFactory( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } + + // Register Velox connector factory for iceberg. + // The iceberg catalog is handled by the hive connector factory. + if (!velox::connector::hasConnectorFactory(kIcebergConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared( + kIcebergConnectorName)); + } + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + if (!velox::connector::hasConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } +#endif +} +} // namespace + +void registerConnectors() { + registerConnectorFactories(); + + registerPrestoToVeloxConnector(std::make_unique( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique(kHiveHadoop2ConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique(kIcebergConnectorName)); + registerPrestoToVeloxConnector(std::make_unique( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)); + + // Presto server uses system catalog or system schema in other catalogs + // in different places in the code. All these resolve to the SystemConnector. + // Depending on where the operator or column is used, different prefixes can + // be used in the naming. So the protocol class is mapped + // to all the different prefixes for System tables/columns. + registerPrestoToVeloxConnector( + std::make_unique("$system")); + registerPrestoToVeloxConnector( + std::make_unique("system")); + registerPrestoToVeloxConnector( + std::make_unique("$system@system")); + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + registerPrestoToVeloxConnector(std::make_unique( + ArrowFlightConnectorFactory::kArrowFlightConnectorName)); +#endif +} +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.h b/presto-native-execution/presto_cpp/main/connectors/Registration.h new file mode 100644 index 0000000000000..c95aefaacfcaa --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.h @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace facebook::presto { + +void registerConnectors(); + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp similarity index 99% rename from presto-native-execution/presto_cpp/main/SystemConnector.cpp rename to presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp index 7622d203e8689..eb9fb48196e9d 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "presto_cpp/main/SystemConnector.h" +#include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/PrestoTask.h" #include "presto_cpp/main/TaskManager.h" diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.h b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h similarity index 98% rename from presto-native-execution/presto_cpp/main/SystemConnector.h rename to presto-native-execution/presto_cpp/main/connectors/SystemConnector.h index 52d9df595f736..e7ffd7f2519b6 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h @@ -13,13 +13,12 @@ */ #pragma once -#include "presto_cpp/main/SystemSplit.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/SystemSplit.h" #include "velox/connectors/Connector.h" namespace facebook::presto { - class TaskManager; class SystemColumnHandle : public velox::connector::ColumnHandle { diff --git a/presto-native-execution/presto_cpp/main/SystemSplit.h b/presto-native-execution/presto_cpp/main/connectors/SystemSplit.h similarity index 100% rename from presto-native-execution/presto_cpp/main/SystemSplit.h rename to presto-native-execution/presto_cpp/main/connectors/SystemSplit.h diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp new file mode 100644 index 0000000000000..03c2db85d0635 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" + +namespace facebook::presto { + +std::string ArrowFlightConfig::authenticatorName() { + return config_->get(kAuthenticatorName, "none"); +} + +std::optional ArrowFlightConfig::defaultServerHostname() { + return static_cast>( + config_->get(kDefaultServerHost)); +} + +std::optional ArrowFlightConfig::defaultServerPort() { + return static_cast>( + config_->get(kDefaultServerPort)); +} + +bool ArrowFlightConfig::defaultServerSslEnabled() { + return config_->get(kDefaultServerSslEnabled, false); +} + +bool ArrowFlightConfig::serverVerify() { + return config_->get(kServerVerify, true); +} + +std::optional ArrowFlightConfig::serverSslCertificate() { + return static_cast>( + config_->get(kServerSslCertificate)); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h new file mode 100644 index 0000000000000..59e006175976e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/config/Config.h" + +namespace facebook::presto { + +class ArrowFlightConfig { + public: + explicit ArrowFlightConfig( + std::shared_ptr config) + : config_{config} {} + + static constexpr const char* kAuthenticatorName = + "arrow-flight.authenticator.name"; + + static constexpr const char* kDefaultServerHost = "arrow-flight.server"; + + static constexpr const char* kDefaultServerPort = "arrow-flight.server.port"; + + static constexpr const char* kDefaultServerSslEnabled = + "arrow-flight.server-ssl-enabled"; + + static constexpr const char* kServerVerify = "arrow-flight.server.verify"; + + static constexpr const char* kServerSslCertificate = + "arrow-flight.server-ssl-certificate"; + + std::string authenticatorName(); + + std::optional defaultServerHostname(); + + std::optional defaultServerPort(); + + bool defaultServerSslEnabled(); + + bool serverVerify(); + + std::optional serverSslCertificate(); + + private: + std::shared_ptr config_; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp new file mode 100644 index 0000000000000..3c798d94379b1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp @@ -0,0 +1,188 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include +#include +#include +#include +#include "presto_cpp/main/common/ConfigReader.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/vector/arrow/Bridge.h" + +using namespace arrow::flight; +using namespace facebook::velox::connector; + +namespace facebook::presto { + +// Wrapper for CallOptions which does not add any member variables, +// but provides a write-only interface for adding call headers. +class CallOptionsAddHeaders : public FlightCallOptions, public AddCallHeaders { + public: + void AddHeader(const std::string& key, const std::string& value) override { + headers.emplace_back(key, value); + } +}; + +std::optional ArrowFlightConnector::getDefaultLocation( + const std::shared_ptr& config) { + auto defaultHost = config->defaultServerHostname(); + auto defaultPort = config->defaultServerPort(); + if (!defaultHost.has_value() || !defaultPort.has_value()) { + return std::nullopt; + } + + bool defaultSslEnabled = config->defaultServerSslEnabled(); + AFC_RETURN_OR_RAISE( + defaultSslEnabled + ? Location::ForGrpcTls(defaultHost.value(), defaultPort.value()) + : Location::ForGrpcTcp(defaultHost.value(), defaultPort.value())); +} + +std::shared_ptr +ArrowFlightConnector::initClientOpts( + const std::shared_ptr& config) { + auto clientOpts = std::make_shared(); + clientOpts->disable_server_verification = !config->serverVerify(); + + auto certPath = config->serverSslCertificate(); + if (certPath.has_value()) { + std::ifstream file(certPath.value()); + VELOX_CHECK(file.is_open(), "Could not open TLS certificate"); + std::string cert( + (std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + clientOpts->tls_root_certs = cert; + } + + return clientOpts; +} + +ArrowFlightDataSource::ArrowFlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map>& + columnHandles, + std::shared_ptr authenticator, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional& defaultLocation) + : outputType_{outputType}, + authenticator_{std::move(authenticator)}, + connectorQueryCtx_{connectorQueryCtx}, + flightConfig_{flightConfig}, + clientOpts_{clientOpts}, + defaultLocation_{defaultLocation} { + // columnMapping_ contains the real column names in the expected order. + // This is later used by projectOutputColumns to filter out unnecessary + // columns from the fetched chunk. + columnMapping_.reserve(outputType_->size()); + + for (const auto& columnName : outputType_->names()) { + auto it = columnHandles.find(columnName); + VELOX_CHECK( + it != columnHandles.end(), + "missing columnHandle for column '{}'", + columnName); + + auto handle = + std::dynamic_pointer_cast(it->second); + VELOX_CHECK_NOT_NULL( + handle, + "handle for column '{}' is not an ArrowFlightColumnHandle", + columnName); + + columnMapping_.push_back(handle->name()); + } +} + +void ArrowFlightDataSource::addSplit(std::shared_ptr split) { + auto flightSplit = std::dynamic_pointer_cast(split); + VELOX_CHECK( + flightSplit, "ArrowFlightDataSource received wrong type of split"); + + auto flightEndpointStr = + folly::base64Decode(flightSplit->flightEndpointBytes_); + + FlightEndpoint flightEndpoint; + AFC_ASSIGN_OR_RAISE( + flightEndpoint, + arrow::flight::FlightEndpoint::Deserialize(flightEndpointStr)); + + Location loc; + if (!flightEndpoint.locations.empty()) { + loc = flightEndpoint.locations[0]; + } else { + VELOX_CHECK( + defaultLocation_.has_value(), + "No location from Flight endpoint, default host or port is missing"); + loc = defaultLocation_.value(); + } + VELOX_CHECK_NOT_NULL(clientOpts_, "FlightClientOptions is not initialized"); + + AFC_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(loc, *clientOpts_)); + + CallOptionsAddHeaders callOptsAddHeaders{}; + authenticator_->authenticateClient( + client, connectorQueryCtx_->sessionProperties(), callOptsAddHeaders); + + auto readerResult = client->DoGet(callOptsAddHeaders, flightEndpoint.ticket); + AFC_ASSIGN_OR_RAISE(currentReader_, readerResult); +} + +std::optional ArrowFlightDataSource::next( + uint64_t size, + velox::ContinueFuture& /* unused */) { + VELOX_CHECK_NOT_NULL(currentReader_, "Missing split, call addSplit() first"); + + AFC_ASSIGN_OR_RAISE(auto chunk, currentReader_->Next()); + + // Null values in the chunk indicates that the Flight stream is complete. + if (!chunk.data) { + currentReader_ = nullptr; + return nullptr; + } + + // Extract only required columns from the record batch as a velox RowVector. + auto output = projectOutputColumns(chunk.data); + + completedRows_ += output->size(); + completedBytes_ += output->inMemoryBytes(); + return output; +} + +velox::RowVectorPtr ArrowFlightDataSource::projectOutputColumns( + const std::shared_ptr& input) { + velox::memory::MemoryPool* pool = connectorQueryCtx_->memoryPool(); + std::vector children; + children.reserve(columnMapping_.size()); + + // Extract and convert desired columns in the correct order. + for (const auto& name : columnMapping_) { + auto column = input->GetColumnByName(name); + VELOX_CHECK_NOT_NULL(column, "column with name '{}' not found", name); + ArrowArray array; + ArrowSchema schema; + AFC_RAISE_NOT_OK(arrow::ExportArray(*column, &array, &schema)); + children.push_back(velox::importFromArrowAsOwner(schema, array, pool)); + } + + return std::make_shared( + pool, + outputType_, + velox::BufferPtr() /*nulls*/, + input->num_rows(), + std::move(children)); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h new file mode 100644 index 0000000000000..370e7d3b40061 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h @@ -0,0 +1,194 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/connectors/Connector.h" + +namespace facebook::presto { + +class ArrowFlightTableHandle : public velox::connector::ConnectorTableHandle { + public: + explicit ArrowFlightTableHandle(const std::string& connectorId) + : ConnectorTableHandle(connectorId) {} +}; + +struct ArrowFlightSplit : public velox::connector::ConnectorSplit { + /// @param connectorId + /// @param flightEndpointBytes Base64 Serialized `FlightEndpoint` + ArrowFlightSplit( + const std::string& connectorId, + const std::string& flightEndpointBytes) + : ConnectorSplit(connectorId), + flightEndpointBytes_(flightEndpointBytes) {} + + const std::string flightEndpointBytes_; +}; + +class ArrowFlightColumnHandle : public velox::connector::ColumnHandle { + public: + explicit ArrowFlightColumnHandle(const std::string& columnName) + : columnName_(columnName) {} + + const std::string& name() { + return columnName_; + } + + private: + std::string columnName_; +}; + +class ArrowFlightDataSource : public velox::connector::DataSource { + public: + ArrowFlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + std::shared_ptr authenticator, + const velox::connector::ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional& defaultLocation = + std::nullopt); + + void addSplit( + std::shared_ptr split) override; + + std::optional next( + uint64_t size, + velox::ContinueFuture& /* unused */) override; + + void addDynamicFilter( + velox::column_index_t outputChannel, + const std::shared_ptr& filter) override { + VELOX_NYI("This connector doesn't support dynamic filters"); + } + + uint64_t getCompletedBytes() override { + return completedBytes_; + } + + uint64_t getCompletedRows() override { + return completedRows_; + } + + std::unordered_map runtimeStats() + override { + return {}; + } + + private: + /// Convert an Arrow record batch to Velox RowVector. + /// Process only those columns that are present in outputType_. + velox::RowVectorPtr projectOutputColumns( + const std::shared_ptr& input); + + velox::RowTypePtr outputType_; + std::vector columnMapping_; + std::unique_ptr currentReader_; + uint64_t completedRows_ = 0; + uint64_t completedBytes_ = 0; + std::shared_ptr authenticator_; + const velox::connector::ConnectorQueryCtx* const connectorQueryCtx_; + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; +}; + +class ArrowFlightConnector : public velox::connector::Connector { + public: + explicit ArrowFlightConnector( + const std::string& id, + std::shared_ptr config, + const char* authenticatorName = nullptr) + : Connector(id), + flightConfig_(std::make_shared(config)), + clientOpts_(initClientOpts(flightConfig_)), + defaultLocation_(getDefaultLocation(flightConfig_)), + authenticator_(getAuthenticatorFactory( + authenticatorName + ? authenticatorName + : flightConfig_->authenticatorName()) + ->newAuthenticator(config)) {} + + std::unique_ptr createDataSource( + const velox::RowTypePtr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::connector::ConnectorQueryCtx* connectorQueryCtx) override { + return std::make_unique( + outputType, + columnHandles, + authenticator_, + connectorQueryCtx, + flightConfig_, + clientOpts_, + defaultLocation_); + } + + std::unique_ptr createDataSink( + velox::RowTypePtr inputType, + std::shared_ptr + connectorInsertTableHandle, + velox::connector::ConnectorQueryCtx* connectorQueryCtx, + velox::connector::CommitStrategy commitStrategy) override { + VELOX_NYI("The arrow-flight connector does not support a DataSink"); + } + + private: + // Returns the default location specified in the ArrowFlightConfig. + // Returns nullopt if either host or port is missing. + static std::optional getDefaultLocation( + const std::shared_ptr& config); + + static std::shared_ptr initClientOpts( + const std::shared_ptr& config); + + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; + const std::shared_ptr authenticator_; +}; + +class ArrowFlightConnectorFactory : public velox::connector::ConnectorFactory { + public: + static constexpr const char* kArrowFlightConnectorName = "arrow-flight"; + + ArrowFlightConnectorFactory() : ConnectorFactory(kArrowFlightConnectorName) {} + + explicit ArrowFlightConnectorFactory( + const char* name, + const char* authenticatorName = nullptr) + : ConnectorFactory(name), authenticatorName_(authenticatorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared( + id, config, authenticatorName_); + } + + private: + const char* authenticatorName_{nullptr}; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..1ac5ab838f1db --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h" + +namespace facebook::presto { + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* const connectorSplit, + const protocol::SplitContext* /*splitContext*/) const { + auto arrowSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + arrowSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + catalogId, arrowSplit->flightEndpointBytes); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& /*typeParser*/) const { + auto arrowColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + arrowColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + arrowColumn->columnName); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& /*exprConverter*/, + const TypeParser& /*typeParser*/, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) const { + return std::make_unique( + tableHandle.connectorId); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..fa7ab67b9c0b7 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" + +namespace facebook::presto { + +class ArrowPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit ArrowPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) + const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt new file mode 100644 index 0000000000000..8a973f46b38a2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +find_package(Arrow REQUIRED) +find_package(PkgConfig REQUIRED) +pkg_check_modules(ARROW_FLIGHT REQUIRED IMPORTED_TARGET GLOBAL arrow-flight) + +if(NOT ARROW_FLIGHT_FOUND) + message(FATAL_ERROR "Arrow Flight package not found") +endif() + +set(ArrowFlight_FOUND TRUE) +set(ArrowFlight_INCLUDE_DIRS ${ARROW_FLIGHT_INCLUDE_DIRS}) +set(ArrowFlight_LIBRARIES ${ARROW_FLIGHT_LIBRARIES}) +include_directories(${ArrowFlight_INCLUDE_DIRS}) + +add_subdirectory(auth) + +add_library(presto_flight_connector_utils INTERFACE Macros.h) +target_link_libraries(presto_flight_connector_utils INTERFACE velox_exception) + +add_library( + presto_flight_connector OBJECT + ArrowFlightConnector.cpp ArrowPrestoToVeloxConnector.cpp + ArrowFlightConfig.cpp) + +target_compile_definitions(presto_flight_connector + PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + +target_link_libraries( + presto_flight_connector velox_connector PkgConfig::ARROW_FLIGHT + presto_flight_connector_utils presto_flight_connector_auth presto_types) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h new file mode 100644 index 0000000000000..5ab725e582cc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/base/Exceptions.h" + +// Macros for dealing with arrow::Status and arrow::Result objects +// and converting them to velox exceptions. + +/// Raise a Velox exception if status is not OK. +/// Counterpart of ARROW_RETURN_NOT_OK. +#define AFC_RAISE_NOT_OK(status) \ + do { \ + ::arrow::Status __s = ::arrow::internal::GenericToStatus(status); \ + VELOX_CHECK(__s.ok(), __s.message()); \ + } while (false) + +#define AFC_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + VELOX_CHECK((result_name).ok(), (result_name).status().message()); \ + lhs = std::move(result_name).ValueUnsafe(); + +/// Raise a Velox exception if expr doesn't return an OK result, +/// else unwrap the value and assign it to `lhs`. +/// `std::move`s its right hand operand. +/// Counterpart of ARROW_ASSIGN_OR_RAISE. +#define AFC_ASSIGN_OR_RAISE(lhs, rexpr) \ + AFC_ASSIGN_OR_RAISE_IMPL( \ + ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr); + +/// Raise a Velox exception if rexpr doesn't return an OK result, +/// else unwrap the value and return it. +/// `std::move`s its right hand operand. +#define AFC_RETURN_OR_RAISE(rexpr) \ + do { \ + auto&& __r = (rexpr); \ + VELOX_CHECK(__r.ok(), __r.status().message()); \ + return std::move(__r).ValueUnsafe(); \ + } while (false) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp new file mode 100644 index 0000000000000..d82ed1f94f3dc --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto { +namespace { +auto& authenticatorFactories() { + static std::unordered_map> + factories; + return factories; +} +} // namespace + +bool registerAuthenticatorFactory( + std::shared_ptr factory) { + bool ok = authenticatorFactories().insert({factory->name(), factory}).second; + VELOX_CHECK( + ok, + "Flight AuthenticatorFactory with name {} is already registered", + factory->name()); + return true; +} + +std::shared_ptr getAuthenticatorFactory( + const std::string& name) { + auto it = authenticatorFactories().find(name); + VELOX_CHECK( + it != authenticatorFactories().end(), + "Flight AuthenticatorFactory with name {} not registered", + name); + return it->second; +} + +AFC_REGISTER_AUTH_FACTORY(std::make_shared()) + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h new file mode 100644 index 0000000000000..db51bb3208a67 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/common/config/Config.h" + +namespace facebook::presto { + +class Authenticator { + public: + /// @brief Override this method to define implementation-specific + /// authentication This could be through client->Authenticate, or + /// client->AuthenticateBasicToken or any other custom strategy + /// @param client the Flight client which is to be authenticated + /// @param sessionProperties connector session properties + /// @param headerWriter write-only object used to set authentication headers + virtual void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) = 0; +}; + +class AuthenticatorFactory { + public: + explicit AuthenticatorFactory(std::string_view name) : name_{name} {} + + const std::string& name() const { + return name_; + } + + virtual std::shared_ptr newAuthenticator( + std::shared_ptr config) = 0; + + private: + std::string name_; +}; + +bool registerAuthenticatorFactory( + std::shared_ptr factory); + +std::shared_ptr getAuthenticatorFactory( + const std::string& name); + +#define AFC_REGISTER_AUTH_FACTORY(factory) \ + namespace { \ + static bool FB_ANONYMOUS_VARIABLE(g_ConnectorFactory) = \ + ::facebook::presto::registerAuthenticatorFactory((factory)); \ + } + +class NoOpAuthenticator : public Authenticator { + public: + void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) override {} +}; + +class NoOpAuthenticatorFactory : public AuthenticatorFactory { + public: + static constexpr const std::string_view kNoOpAuthenticatorName{"none"}; + + NoOpAuthenticatorFactory() : AuthenticatorFactory{kNoOpAuthenticatorName} {} + + explicit NoOpAuthenticatorFactory(std::string_view name) + : AuthenticatorFactory{name} {} + + std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return std::make_shared(); + } +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt new file mode 100644 index 0000000000000..1e7eba3154a0e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_flight_connector_auth Authenticator.cpp) + +target_link_libraries(presto_flight_connector_auth + presto_flight_connector_utils velox_exception) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp new file mode 100644 index 0000000000000..eb946f1fcae76 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" +#include + +using namespace facebook::velox; +using namespace facebook::presto; + +TEST(ArrowFlightConfigTest, defaultConfig) { + auto rawConfig = std::make_shared( + std::move(std::unordered_map{})); + auto config = ArrowFlightConfig(rawConfig); + ASSERT_EQ(config.authenticatorName(), "none"); + ASSERT_EQ(config.defaultServerHostname(), std::nullopt); + ASSERT_EQ(config.defaultServerPort(), std::nullopt); + ASSERT_EQ(config.defaultServerSslEnabled(), false); + ASSERT_EQ(config.serverVerify(), true); + ASSERT_EQ(config.serverSslCertificate(), std::nullopt); +} + +TEST(ArrowFlightConfigTest, overrideConfig) { + std::unordered_map configMap = { + {ArrowFlightConfig::kAuthenticatorName, "my-authenticator"}, + {ArrowFlightConfig::kDefaultServerHost, "my-server-host"}, + {ArrowFlightConfig::kDefaultServerPort, "9000"}, + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "false"}, + {ArrowFlightConfig::kServerSslCertificate, "my-cert.crt"}}; + auto config = ArrowFlightConfig( + std::make_shared(std::move(configMap))); + ASSERT_EQ(config.authenticatorName(), "my-authenticator"); + ASSERT_EQ(config.defaultServerHostname(), "my-server-host"); + ASSERT_EQ(config.defaultServerPort(), 9000); + ASSERT_EQ(config.defaultServerSslEnabled(), true); + ASSERT_EQ(config.serverVerify(), false); + ASSERT_EQ(config.serverSslCertificate(), "my-cert.crt"); +} diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp new file mode 100644 index 0000000000000..0e2139d1eafec --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp @@ -0,0 +1,236 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class TestingServerMiddlewareFactory : public flight::ServerMiddlewareFactory { + public: + static constexpr const char* kAuthHeader = "authorization"; + static constexpr const char* kAuthToken = "Bearer 1234"; + static constexpr const char* kAuthTokenUnauthorized = "Bearer 2112"; + + arrow::Status StartCall( + const flight::CallInfo& info, + const flight::ServerCallContext& context, + std::shared_ptr* middleware) override { + auto iter = context.incoming_headers().find(kAuthHeader); + + if (iter == context.incoming_headers().end()) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthenticated, + "Authorization token not provided"); + } else { + std::lock_guard l(mutex_); + checkedTokens_.emplace_back(iter->second); + } + + if (kAuthToken != iter->second) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthorized, + "Authorization token is invalid"); + } + + return arrow::Status::OK(); + } + + bool isTokenChecked(const std::string& authToken) { + { + std::lock_guard l(mutex_); + return std::find( + checkedTokens_.begin(), checkedTokens_.end(), authToken) != + checkedTokens_.end(); + } + } + + private: + std::string validToken_; + std::vector checkedTokens_; + std::mutex mutex_; +}; + +class TestingAuthenticator : public Authenticator { + public: + explicit TestingAuthenticator(const std::string& authToken) + : authToken_(authToken) {} + + void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) override { + if (!authToken_.empty()) { + headerWriter.AddHeader( + TestingServerMiddlewareFactory::kAuthHeader, authToken_); + } + } + + private: + std::string authToken_; +}; + +class TestingAuthenticatorFactory : public AuthenticatorFactory { + public: + TestingAuthenticatorFactory( + const std::string& name, + const std::string& authToken) + : AuthenticatorFactory(name), + testingAuthenticator_{ + std::make_shared(authToken)} {} + + std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return testingAuthenticator_; + } + + private: + std::shared_ptr testingAuthenticator_; +}; + +namespace { +constexpr const char* kAuthFactoryName = "testing-auth-valid"; +constexpr const char* kAuthFactoryUnauthorizedName = + "testing-auth-unauthorized"; +constexpr const char* kAuthFactoryNoTokenName = "testing-auth-no-token"; + +bool registerTestAuthFactories() { + static bool once = [] { + auto authFactory = std::make_shared( + kAuthFactoryName, TestingServerMiddlewareFactory::kAuthToken); + registerAuthenticatorFactory(authFactory); + auto authFactoryUnauthorized = + std::make_shared( + kAuthFactoryUnauthorizedName, + TestingServerMiddlewareFactory::kAuthTokenUnauthorized); + registerAuthenticatorFactory(authFactoryUnauthorized); + auto authFactoryNoToken = std::make_shared( + kAuthFactoryNoTokenName, ""); + registerAuthenticatorFactory(authFactoryNoToken); + return true; + }(); + return once; +} +} // namespace + +class ArrowFlightConnectorAuthTestBase : public FlightWithServerTestBase { + public: + explicit ArrowFlightConnectorAuthTestBase(const std::string& authFactoryName) + : FlightWithServerTestBase(std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kAuthenticatorName, authFactoryName}})), + testingMiddlewareFactory_( + std::make_shared()) {} + + void SetUp() override { + registerTestAuthFactories(); + FlightWithServerTestBase::SetUp(); + } + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + serverOptions->middleware.push_back( + {"bearer-auth-server", testingMiddlewareFactory_}); + } + + core::PlanNodePtr addSampleDataAndRunQuery() { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()})})); + + return ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + } + + protected: + std::shared_ptr testingMiddlewareFactory_; +}; + +class ArrowFlightConnectorAuthTest : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorAuthTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryName) {} +}; + +TEST_F(ArrowFlightConnectorAuthTest, customAuthenticator) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + ASSERT_TRUE(testingMiddlewareFactory_->isTokenChecked( + TestingServerMiddlewareFactory::kAuthToken)); +} + +class ArrowFlightConnectorUnauthorizedTest + : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorUnauthorizedTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryUnauthorizedName) {} +}; + +TEST_F(ArrowFlightConnectorUnauthorizedTest, unauthorizedToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthorized"); +} + +class ArrowFlightConnectorUnauthenticatedTest + : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorUnauthenticatedTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryNoTokenName) {} +}; + +TEST_F(ArrowFlightConnectorUnauthenticatedTest, unauthenticatedNoToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthenticated"); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp new file mode 100644 index 0000000000000..72c87a4f3b53c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp @@ -0,0 +1,328 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorDataTypeTest : public FlightWithServerTestBase {}; + +TEST_F(ArrowFlightConnectorDataTypeTest, booleanType) { + updateTable( + "sample-data", + makeArrowTable( + {"bool_col"}, {makeBooleanArray({true, false, true, false})})); + + auto boolVec = makeFlatVector({true, false, true, false}); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"bool_col"}, {velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({boolVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, integerTypes) { + std::vector tinyData = { + -128, 0, 127, std::numeric_limits::max()}; + std::vector smallData = { + -32768, 0, 32767, std::numeric_limits::max()}; + std::vector intData = { + -2147483648, 0, 2147483647, std::numeric_limits::max()}; + std::vector bigData = { + -3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {makeNumericArray(tinyData), + makeNumericArray(smallData), + makeNumericArray(intData), + makeNumericArray(bigData)})); + + auto tinyintVec = makeFlatVector(tinyData); + + auto smallintVec = makeFlatVector(smallData); + + auto integerVec = makeFlatVector(intData); + + auto bigintVec = makeFlatVector(bigData); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {velox::TINYINT(), + velox::SMALLINT(), + velox::INTEGER(), + velox::BIGINT()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults( + makeRowVector({tinyintVec, smallintVec, integerVec, bigintVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, realType) { + std::vector realData = { + std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}; + std::vector doubleData = { + std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable( + {"real_col", "double_col"}, + {makeNumericArray(realData), + makeNumericArray(doubleData)})); + + auto realVec = makeFlatVector(realData); + auto doubleVec = makeFlatVector(doubleData); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({realVec, doubleVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, varcharType) { + updateTable( + "sample-data", + makeArrowTable( + {"varchar_col"}, {makeStringArray({"Hello", "World", "India"})})); + + auto vec = makeFlatVector( + {facebook::velox::StringView("Hello"), + facebook::velox::StringView("World"), + facebook::velox::StringView("India")}); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"varchar_col"}, {velox::VARCHAR()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, timestampType) { + auto timestampValues = + std::vector{1622538000, 1622541600, 1622545200}; + + updateTable( + "sample-data", + makeArrowTable( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeTimestampArray(timestampValues, arrow::TimeUnit::MILLI), + makeTimestampArray(timestampValues, arrow::TimeUnit::MICRO)})); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + + auto timestampSecCol = + makeFlatVector(veloxTimestampSec); + + std::vector veloxTimestampMilli; + for (const auto& ts : timestampValues) { + veloxTimestampMilli.emplace_back( + ts / 1000, (ts % 1000) * 1000000); // Convert to seconds and nanoseconds + } + + auto timestampMilliCol = + makeFlatVector(veloxTimestampMilli); + + std::vector veloxTimestampMicro; + for (const auto& ts : timestampValues) { + veloxTimestampMicro.emplace_back( + ts / 1000000, + (ts % 1000000) * 1000); // Convert to seconds and nanoseconds + } + + auto timestampMicroCol = + makeFlatVector(veloxTimestampMicro); + + core::PlanNodePtr plan; + plan = + ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector( + {timestampSecCol, timestampMilliCol, timestampMicroCol})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, dateDayType) { + std::vector datesDay = {18748, 18749, 18750}; // Days since epoch + std::vector datesMilli = { + 1622538000000, 1622541600000, 1622545200000}; // Milliseconds since epoch + + updateTable( + "sample-data", + makeArrowTable( + {"daydate_col", "daymilli_col"}, + {makeNumericArray(datesDay), + makeNumericArray(datesMilli)})); + + auto dateVec = makeFlatVector(datesDay); + auto milliVec = makeFlatVector(datesMilli); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"daydate_col"}, {velox::DATE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({dateVec})); + + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"daymilli_col"}, {velox::DATE()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({milliVec})), + "Unable to convert 'tdm' ArrowSchema format type to Velox"); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, decimalType) { + std::vector decimalValuesBigInt = { + 123456789012345678, + -123456789012345678, + std::numeric_limits::max()}; + std::vector> decimalArrayVec; + decimalArrayVec.push_back(makeDecimalArray(decimalValuesBigInt, 18, 2)); + updateTable( + "sample-data", makeArrowTable({"decimal_col_bigint"}, decimalArrayVec)); + auto decimalVecBigInt = makeFlatVector(decimalValuesBigInt); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"decimal_col_bigint"}, + {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale + .planNode(); + + // Execute the query and assert the results + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({decimalVecBigInt})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, allTypes) { + auto timestampValues = + std::vector{1622550000, 1622553600, 1622557200}; + + auto sampleTable = makeArrowTable( + {"id", + "daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {makeNumericArray({1, 2, 3}), + makeNumericArray({18748, 18749, 18750}), + makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeStringArray({"apple", "banana", "cherry"}), + makeNumericArray({3.14, 2.718, 1.618}), + makeNumericArray( + {-32768, 32767, std::numeric_limits::max()}), + makeBooleanArray({true, false, true})}); + + updateTable("gen-data", sampleTable); + + auto dateVec = makeFlatVector({18748, 18749, 18750}); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + auto timestampSecVec = + makeFlatVector(veloxTimestampSec); + + auto stringVec = makeFlatVector( + {facebook::velox::StringView("apple"), + facebook::velox::StringView("banana"), + facebook::velox::StringView("cherry")}); + auto realVec = makeFlatVector({3.14, 2.718, 1.618}); + auto intVec = makeFlatVector( + {-32768, 32767, std::numeric_limits::max()}); + auto boolVec = makeFlatVector({true, false, true}); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {velox::DATE(), + velox::TIMESTAMP(), + velox::VARCHAR(), + velox::DOUBLE(), + velox::INTEGER(), + velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"gen-data"})) + .assertResults(makeRowVector( + {dateVec, timestampSecVec, stringVec, realVec, intVec, boolVec})); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp new file mode 100644 index 0000000000000..d538b591027c4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorTest : public FlightWithServerTestBase {}; + +TEST_F(ArrowFlightConnectorTest, invalidSplit) { + auto plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({{"id", velox::BIGINT()}})) + .planNode(); + + VELOX_ASSERT_THROW( + velox::exec::test::AssertQueryBuilder(plan) + .splits(makeSplits({"unknown"})) + .copyResults(pool()), + "table does not exist"); +} + +TEST_F(ArrowFlightConnectorTest, dataSourceCreation) { + // missing columnHandle test + auto plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()}), + {{"id", std::make_shared("id")}}, + false /*createDefaultColumnHandles*/) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "missing columnHandle for column 'value'"); +} + +TEST_F(ArrowFlightConnectorTest, dataSource) { + std::vector idData = {1, 12, 2, std::numeric_limits::max()}; + std::vector valueData = { + 41, 42, 43, std::numeric_limits::min()}; + std::vector unsignedData = { + 41, 42, 43, std::numeric_limits::min()}; + + updateTable( + "sample-data", + makeArrowTable( + {"id", "value", "unsigned"}, + {makeNumericArray(idData), + makeNumericArray(valueData), + // note that velox doesn't support unsigned types + // connector should still be able to query such tables + // as long as this specific column isn't requested. + makeNumericArray(unsignedData)})); + + auto idColumn = std::make_shared("id"); + auto idVec = makeFlatVector(idData); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector(valueData); + + core::PlanNodePtr plan; + + // direct test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, std::vector{})) + .assertResults(makeRowVector({idVec, valueVec})), + "default host or port is missing"); + + // column alias test + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"ducks", "id"}, {velox::BIGINT(), velox::BIGINT()}), + {{"ducks", idColumn}}) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, idVec})); + + // invalid columnHandle test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "column with name 'ducks' not found"); +} + +class ArrowFlightConnectorTestDefaultServer : public FlightWithServerTestBase { + public: + ArrowFlightConnectorTestDefaultServer() + : FlightWithServerTestBase(std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerHost, CONNECT_HOST}, + {ArrowFlightConfig::kDefaultServerPort, + std::to_string(LISTEN_PORT)}})) {} +}; + +TEST_F(ArrowFlightConnectorTestDefaultServer, dataSource) { + std::vector idData = {1, 12, 2, std::numeric_limits::max()}; + std::vector valueData = { + 41, 42, 43, std::numeric_limits::min()}; + + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray(idData), + makeNumericArray(valueData)})); + + auto idColumn = std::make_shared("id"); + auto idVec = makeFlatVector(idData); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector(valueData); + + core::PlanNodePtr plan; + + // direct test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + AssertQueryBuilder(plan) + .splits(makeSplits( + {"sample-data"}, + std::vector{})) // Using default connector + .assertResults(makeRowVector({idVec, valueVec})); +} + +} // namespace facebook::presto::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp new file mode 100644 index 0000000000000..0c8a5a9e46453 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorTlsTestBase : public FlightWithServerTestBase { + protected: + explicit ArrowFlightConnectorTlsTestBase( + std::shared_ptr config) + : FlightWithServerTestBase(std::move(config)) {} + + flight::Location getServerLocation() override { + AFC_ASSIGN_OR_RAISE( + auto loc, flight::Location::ForGrpcTls(BIND_HOST, LISTEN_PORT)); + return loc; + } + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + flight::CertKeyPair tlsCertificate{ + .pem_cert = readFile("./data/tls_certs/server.crt"), + .pem_key = readFile("./data/tls_certs/server.key")}; + serverOptions->tls_certificates.push_back(tlsCertificate); + } + + void executeTest( + bool isPositiveTest = true, + const std::string& expectedError = "") { + std::vector idData = { + 1, 12, 2, std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable({"id"}, {makeNumericArray(idData)})); + + auto idVec = makeFlatVector(idData); + + auto plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"id"}, {velox::BIGINT()})) + .planNode(); + + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTls(CONNECT_HOST, LISTEN_PORT)); + auto locs = std::vector{loc}; + if (isPositiveTest) { + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, locs)) + .assertResults(makeRowVector({idVec})); + } else { + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, locs)) + .assertResults(makeRowVector({idVec})), + expectedError); + } + } +}; + +class ArrowFlightConnectorTlsTest : public ArrowFlightConnectorTlsTestBase { + protected: + explicit ArrowFlightConnectorTlsTest() + : ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/tls_certs/ca.crt"}})) {} +}; + +TEST_F(ArrowFlightConnectorTlsTest, tlsEnabled) { + executeTest(); +} + +class ArrowFlightTlsNoCertValidationTest + : public ArrowFlightConnectorTlsTestBase { + protected: + explicit ArrowFlightTlsNoCertValidationTest() + : ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kServerVerify, "false"}})) {} +}; + +TEST_F(ArrowFlightTlsNoCertValidationTest, tlsNoCertValidation) { + executeTest(); +} + +class ArrowFlightTlsNoCertTest : public ArrowFlightConnectorTlsTestBase { + protected: + ArrowFlightTlsNoCertTest() + : ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kServerVerify, "true"}})) {} +}; + +TEST_F(ArrowFlightTlsNoCertTest, tlsNoCert) { + executeTest(false, "handshake failed"); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt new file mode 100644 index 0000000000000..9af596a913973 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_subdirectory(utils) + +add_executable(presto_flight_connector_infra_test + TestingArrowFlightServerTest.cpp) + +add_test(presto_flight_connector_infra_test presto_flight_connector_infra_test) + +target_link_libraries( + presto_flight_connector_infra_test presto_protocol + presto_flight_connector_test_lib GTest::gtest GTest::gtest_main ${GLOG}) + +add_executable( + presto_flight_connector_test + ArrowFlightConnectorTest.cpp ArrowFlightConnectorAuthTest.cpp + ArrowFlightConnectorTlsTest.cpp ArrowFlightConnectorDataTypeTest.cpp + ArrowFlightConfigTest.cpp) + +set(DATA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/data/tls_certs") + +add_custom_target( + copy_flight_test_data ALL + COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATA_DIR} + $/data/tls_certs) + +add_test(presto_flight_connector_test presto_flight_connector_test) + +target_link_libraries( + presto_flight_connector_test + velox_exec_test_lib + presto_flight_connector + gtest + gtest_main + presto_flight_connector_test_lib + presto_protocol) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp new file mode 100644 index 0000000000000..306a80776123e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" +#include +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" + +using namespace arrow; + +namespace facebook::presto::test { + +class TestingArrowFlightServerTest : public testing::Test { + public: + static void SetUpTestSuite() { + server = std::make_unique(); + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTcp("127.0.0.1", 0)); + ASSERT_OK(server->Init(flight::FlightServerOptions(loc))); + } + + static void TearDownTestSuite() { + ASSERT_OK(server->Shutdown()); + } + + static void updateTable( + std::string name, + std::shared_ptr table) { + server->updateTable(std::move(name), std::move(table)); + } + + void SetUp() override { + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTcp("localhost", server->port())); + ASSERT_OK_AND_ASSIGN(client_, flight::FlightClient::Connect(loc)); + } + + std::unique_ptr client_; + static std::unique_ptr server; +}; + +std::unique_ptr TestingArrowFlightServerTest::server; + +TEST_F(TestingArrowFlightServerTest, basicClientConnection) { + auto sampleTable = makeArrowTable( + {"id", "value"}, + {makeNumericArray({1, 2}), + makeNumericArray({41, 42})}); + updateTable("sample-data", sampleTable); + + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"empty"})); + + auto emptyTable = makeArrowTable({}, {}); + updateTable("empty", emptyTable); + + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"non-existent-table"})); + + ASSERT_OK_AND_ASSIGN(auto reader, client_->DoGet(flight::Ticket{"empty"})); + ASSERT_OK_AND_ASSIGN(auto actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*emptyTable)); + + ASSERT_OK_AND_ASSIGN(reader, client_->DoGet(flight::Ticket{"sample-data"})); + ASSERT_OK_AND_ASSIGN(actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*sampleTable)); + + server->removeTable("sample-data"); + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"sample-data"})); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md new file mode 100644 index 0000000000000..3a5f2e5786c67 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md @@ -0,0 +1,7 @@ +### Placeholder TLS Certificates for Arrow Flight Connector Unit Testing +The `tls_certs` directory contains placeholder TLS certificates generated for unit testing the Arrow Flight Connector with TLS enabled. These certificates are not intended for production use and should only be used in the context of unit tests. + +### Generating TLS Certificates +To create the TLS certificates and keys inside the `tls_certs` folder, run the following command: + +`./generate_tls_certs.sh` diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh new file mode 100755 index 0000000000000..718f313c70a75 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Set directory for certificates and keys. +CERT_DIR="./tls_certs" +mkdir -p $CERT_DIR + +# Dummy values for the certificates. +COUNTRY="US" +STATE="State" +LOCALITY="City" +ORGANIZATION="MyOrg" +ORG_UNIT="MyUnit" +COMMON_NAME="MyCA" +SERVER_CN="server.mydomain.com" + +# Step 1: Generate CA private key and self-signed certificate. +openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key +openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" + +# Step 2: Generate server private key. +openssl genpkey -algorithm RSA -out $CERT_DIR/server.key + +# Step 3: Generate server certificate signing request (CSR). +openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ + -addext "subjectAltName=DNS:$COMMON_NAME,DNS:localhost" \ + +# Step 4: Sign server CSR with the CA certificate to generate the server certificate. +openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ + -extfile <(printf "subjectAltName=DNS:$COMMON_NAME,DNS:localhost") + +# Step 5: Output the generated files. +echo "Certificate Authority (CA) certificate: $CERT_DIR/ca.crt" +echo "Server certificate: $CERT_DIR/server.crt" +echo "Server private key: $CERT_DIR/server.key" + +# Step 6: Remove unused files. +rm -rf $CERT_DIR/server.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt new file mode 100644 index 0000000000000..6740e89c54e17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoOgAwIBAgIUf+rP48iL39yGlAfFQTIp5bmM4uQwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBcMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxDTALBgNVBAMMBE15Q0EwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCliiXIcSmxXAAq2k/XjcZniDgEDCxWKZGiV8JBiJwY +MMBJtqcVzWfiDpO2u6d1dfGb6utlRW+1dnwupzURCMmZff4bqlPx4ZejRXDrWzKz +08WSpDVZwC2H5XOllwK36Cn4gvPRe3YWVcdDGHy7GL+zsJENvawJj0BH952MU4bk +sV52zEkN291bfN9sSYfT1NCJuLPM0Qsf97DeQ+wHXEw+t4XVMF3FQbciQp0y6CnA +wfFFN14WDiWxukP1I3kuDYYA6h/WJCQMp5rU2NCB9nIQrulYRxFaepMYENLxgAyj +gFaoRh2Kt2k7XKv6WOa6CmYm2dZERPlbA+oNAHkaHw6lAgMBAAGjUzBRMB0GA1Ud +DgQWBBSN+3vRlXGjs6c+rN94qgEnkPLl3DAfBgNVHSMEGDAWgBSN+3vRlXGjs6c+ +rN94qgEnkPLl3DAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAb +L40Oe2b/2xdUSyjqXJceVxaeA291fCpnu1C1JliP0hNI3fu9jjZhXHQoYub/4mod +8lriEDIcOCCiUfmi404akpqQHuBmOHaKEOtaaQkezjPsYnUra+O2ssqUo2zto5bK +gR0LGsb+4AO0bDvq+QVI6kEQqAAIf6qC+kpg/jV4iKJ1J6Qw4R3QppYBm6SQcfvI +hfUfDSO6SNfy0f/ZVCavbJIP9zG/BfAD9DEERocw03PiN5bm4IXJ3HH8rxyuBfJ5 +Eg/fPP5TlZ2H7Kqb3VgVBGWJtNXWmJphHyraBJTEuxgXWvl6AaW0P/3dsJi3rfdD +zDIT7AmENLCom8Gl0bgM +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt new file mode 100644 index 0000000000000..92c91f2d613b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIUUhmhZP94nIowrg2EarzfEBp6W1EwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME3NlcnZlci5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDSxC4zCC4GFZbX+fdFgWbL +sj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/FkdfMqNN2 ++NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHlKOUWUyNi +EyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moaovmg3c9jM +cBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTRiUYjht7r +pS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull1F7OrbfB +AgMBAAGjXjBcMBoGA1UdEQQTMBGCBE15Q0GCCWxvY2FsaG9zdDAdBgNVHQ4EFgQU +vnCLWjre4jqkKzC24psCPh1oIQwwHwYDVR0jBBgwFoAUjft70ZVxo7OnPqzfeKoB +J5Dy5dwwDQYJKoZIhvcNAQELBQADggEBAJCiJgtTw/7b/g3QnJfM4teQkFS440Ii +weqQJMoP6als8Fc3opPKv9eC5w0wqaLlIdwJjzGM5PmCAtGVafo22TbqhZyQdzQu +TUKv1DaVF0JBVAGVxTSDIK9r5Ww4mDAQnQENLC6soS3AvYDEi+8667YLoNNdhRCX +q2D5v76UN45idiShppxOw53whsvpHv+wyqcdse7DhgM9boCbx51Uvv3l/AEToyaj +S1xeIkBwNpSYU0ax2Lr1j2yoKbzAa3MHy8Php+T5CGji02+HwwlvlPDLtw8q5gHw +BLSwlAHgclPxUTWNNoCqjfX8Bi083+QDCLm0rgQ45xljNDbFAF1Y5hA= +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key new file mode 100644 index 0000000000000..2cdf5750a4753 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDSxC4zCC4GFZbX ++fdFgWbLsj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/F +kdfMqNN2+NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHl +KOUWUyNiEyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moao +vmg3c9jMcBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTR +iUYjht7rpS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull +1F7OrbfBAgMBAAECggEAAxbZuuESGGAMMm9HLGhKHgbHU8gnv2Phdbrka+SYBYg5 +UYzTHLh3FwEsjd4VnaweJ7CN1WDb1NvWmTum/DCebJ1HKqtjKLAZfk8q2TLGmXdL +pzWOdQ8MX1fKP2sIlcl0kFbNCE8vprjneDyBLtqOK36eiAh/fl6BQ12QAMLjyv/L +OwXSY4ESs/RzxRzFgdT98cDZFL7y0FVIjJo/Q5lfW9UwwSfw8tOLNXKTYwPHqIfJ +NjfWD7IqztQlnanyRXv5dScp80i8p9qgH0i8YfVBHZDeOmHGLcltilLRZ0dQ/X0g +Lrr0aIO3iLhmTIkJRzUnGeyvDjxcPINvRSBBwXy04QKBgQDpFJa/EwSsWj8586oh +xgm0Z3q+FiEeCe7aLLPcXAS2EDvix5ibJDT2y1Aa/kXq25S53npa/7Ov6TJs5H4g +eyshDtR1wVhz+rIggREiX/sagkhwnNsssUZFv5t9PdnaFXpVnH49m5Qc8HO3owtN +t8EGSRcAQ4o/fLWLs51qd38cIQKBgQDnfd8YPyDQ03xDC/3+Qrypyc/xhGnCuj7w +ZeA5iEyTnnNxL0a0B6PWcSk2BZReMNQKgYtipnsOQKtwHMttxtXYs/VQpeB4KoWE +zEwW0fV3MMsXN+nVJlEZnVaTbmYXknjeZrh/rNjsY96yxw8NtvAuYSpnqtr3N2nd +iMQ3G/QnoQKBgGMi+bdNvIgeXpQkmrGAzTHpbaCaQv3G1cwAhYPts6dIomAj6znZ +nZl3ApxomI57VPf1s+8uoVvqASOl0Cu6l66Y4y8uzJOQBuGiZApN7rzouy0C2opY +4H3cMKOFgjqrNfxh8qP7n3TrpRxvgehNhxFIVzsqfwvf3EwOWp8lMnBhAoGAZ25E +Ge9K2ENGCCb5i3uCFFLJiF3ja1AQAxVhxBL0NBjd97pp2tJ3D79r7Gk9y4ABndgX +0TIVVV7ruqIC+r+WmMZ/W1NiIg7NrXIipSeWh3TTqUIgRk5iehFkt2biUrHtM2Gu +Gc2+9pAA1tw+C6CrW+2qJrueLksiEAulsAHba0ECgYBIgIiY+Gx+XecEgCwAhWcn +GzNDAAlA4IgBjHpUtIByflzQDqlECKXjPbVBKfyq6eLt40upFmQCLsn+AkiQau8A +3cFAK9wJOAHv9KuWDrbHyhRE9CrJ6BqsY2goC3LiFCTgJy1TrRl6CDaFzHivONwF +LNPflYk5s376UWqxC+HtIA== +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp new file mode 100644 index 0000000000000..89b4ed01e6e89 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" + +using namespace arrow::flight; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +void ArrowFlightConnectorTestBase::SetUp() { + OperatorTestBase::SetUp(); + + if (!velox::connector::hasConnectorFactory( + presto::ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } + velox::connector::registerConnector( + velox::connector::getConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName) + ->newConnector(kFlightConnectorId, config_)); +} + +void ArrowFlightConnectorTestBase::TearDown() { + velox::connector::unregisterConnector(kFlightConnectorId); + OperatorTestBase::TearDown(); +} + +void FlightWithServerTestBase::SetUp() { + ArrowFlightConnectorTestBase::SetUp(); + + FlightServerOptions serverOptions(getServerLocation()); + server_ = std::make_unique(); + setFlightServerOptions(&serverOptions); + ASSERT_OK(server_->Init(serverOptions)); +} + +void FlightWithServerTestBase::TearDown() { + ASSERT_OK(server_->Shutdown()); + ArrowFlightConnectorTestBase::TearDown(); +} + +Location FlightWithServerTestBase::getServerLocation() { + AFC_ASSIGN_OR_RAISE(auto loc, Location::ForGrpcTcp(BIND_HOST, LISTEN_PORT)); + return loc; +} + +std::vector> +FlightWithServerTestBase::makeSplits( + const std::initializer_list& tickets, + const std::vector& locations) { + std::vector> splits; + splits.reserve(tickets.size()); + for (auto& ticket : tickets) { + FlightEndpoint flightEndpoint; + flightEndpoint.ticket.ticket = ticket; + flightEndpoint.locations = locations; + AFC_ASSIGN_OR_RAISE( + auto flightEndpointStr, flightEndpoint.SerializeToString()); + auto flightEndpointBytes = folly::base64Encode(flightEndpointStr); + splits.push_back(std::make_shared( + kFlightConnectorId, flightEndpointBytes)); + } + return splits; +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h new file mode 100644 index 0000000000000..7f9575a2daf78 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/Connector.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::test { + +static const std::string kFlightConnectorId = "test-flight"; + +class ArrowFlightConnectorTestBase + : public velox::exec::test::OperatorTestBase { + public: + void SetUp() override; + + void TearDown() override; + + protected: + explicit ArrowFlightConnectorTestBase( + std::shared_ptr config) + : config_{std::move(config)} {} + + ArrowFlightConnectorTestBase() + : config_{std::make_shared( + std::move(std::unordered_map{}))} {} + + protected: + std::shared_ptr config_; +}; + +/// Creates and registers an Arrow Flight connector and +/// spawns a Flight server for testing. +/// Initially there is no data in the Flight server, +/// tests should call FlightWithServerTestBase::updateTables to populate it. +class FlightWithServerTestBase : public ArrowFlightConnectorTestBase { + public: + static constexpr const char* BIND_HOST = "127.0.0.1"; + static constexpr const char* CONNECT_HOST = "localhost"; + constexpr static int LISTEN_PORT = 5000; + + void SetUp() override; + + void TearDown() override; + + /// Convenience method which creates splits for the test flight server + static std::vector> + makeSplits( + const std::initializer_list& tokens, + const std::vector& locations = + std::vector{ + *arrow::flight::Location::ForGrpcTcp(CONNECT_HOST, LISTEN_PORT)}); + + /// Add (or update) a table in the test flight server + void updateTable(std::string name, std::shared_ptr table) { + server_->updateTable(std::move(name), std::move(table)); + } + + virtual arrow::flight::Location getServerLocation(); + + virtual void setFlightServerOptions( + arrow::flight::FlightServerOptions* serverOptions) {} + + protected: + explicit FlightWithServerTestBase( + std::shared_ptr config) + : ArrowFlightConnectorTestBase{std::move(config)} {} + + FlightWithServerTestBase() : ArrowFlightConnectorTestBase() {} + + private: + std::unique_ptr server_; +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp new file mode 100644 index 0000000000000..0a1e584a677c6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" + +namespace facebook::presto::test { +namespace { +static const std::string kFlightConnectorId = "test-flight"; +} + +velox::exec::test::PlanBuilder& ArrowFlightPlanBuilder::flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments, + bool createDefaultColumnHandles) { + if (createDefaultColumnHandles) { + for (const auto& name : outputType->names()) { + // Provide unaliased defaults for unmapped columns. + // `emplace` won't modify the map if the key already exists, + // so existing aliases are kept. + assignments.emplace( + name, std::make_shared(name)); + } + } + + return startTableScan() + .tableHandle(std::make_shared(kFlightConnectorId)) + .outputType(outputType) + .assignments(std::move(assignments)) + .endTableScan(); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h new file mode 100644 index 0000000000000..5eda2c60aac16 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::presto::test { + +class ArrowFlightPlanBuilder : public velox::exec::test::PlanBuilder { + public: + /// @brief Add a table scan node to the Plan, using the Flight connector + /// @param outputType The output type of the table scan node + /// @param assignments mapping from the column aliases to real column handles + /// @param createDefaultColumnHandles If true, generate column handles for + /// for the columns which don't have an entry in assignments + velox::exec::test::PlanBuilder& flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments = {}, + bool createDefaultColumnHandles = true); +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt new file mode 100644 index 0000000000000..b6d2337a2d301 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library( + presto_flight_connector_test_lib + TestingArrowFlightServer.cpp ArrowFlightConnectorTestBase.cpp Utils.cpp + ArrowFlightPlanBuilder.cpp) + +target_link_libraries( + presto_flight_connector_test_lib arrow presto_flight_connector + velox_exception presto_flight_connector_utils velox_exec_test_lib) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp new file mode 100644 index 0000000000000..ad46bcbc403aa --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" + +using namespace arrow::flight; + +namespace facebook::presto::test { + +arrow::Status TestingArrowFlightServer::DoGet( + const ServerCallContext& context, + const Ticket& request, + std::unique_ptr* stream) { + auto it = tables_.find(request.ticket); + if (it == tables_.end()) { + return arrow::Status::KeyError("requested table does not exist"); + } + auto& table = it->second; + auto reader = std::make_shared(table); + *stream = std::make_unique(std::move(reader)); + return arrow::Status::OK(); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h new file mode 100644 index 0000000000000..f634e437b672c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace facebook::presto::test { + +/// Test Flight server which supports DoGet operations. +/// Maintains a list of named arrow tables, +/// +/// Normally, the tickets would be obtained by calling GetFlightInfo, +/// but since this is done by the coordinator this part is omitted. +/// Instead, the ticket is simply the name of the table to fetch. +class TestingArrowFlightServer : public arrow::flight::FlightServerBase { + public: + TestingArrowFlightServer() = default; + + void updateTable(std::string name, std::shared_ptr table) { + tables_.emplace(std::move(name), std::move(table)); + } + + void removeTable(const std::string& name) { + tables_.erase(name); + } + + arrow::Status DoGet( + const arrow::flight::ServerCallContext& context, + const arrow::flight::Ticket& request, + std::unique_ptr* stream) override; + + private: + std::unordered_map> tables_; +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp new file mode 100644 index 0000000000000..8b44c4995fbed --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Utils.h" +#include +#include + +namespace facebook::presto::test { + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale) { + auto decimalType = arrow::decimal(precision, scale); + auto builder = + arrow::Decimal128Builder(decimalType, arrow::default_memory_pool()); + + for (const auto& value : decimalValues) { + arrow::Decimal128 dec(value); + AFC_RAISE_NOT_OK(builder.Append(dec)); + } + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool) { + arrow::TimestampBuilder builder(arrow::timestamp(timeUnit), memory_pool); + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeStringArray( + const std::vector& values) { + auto builder = arrow::StringBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeBooleanArray( + const std::vector& values) { + auto builder = arrow::BooleanBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays) { + VELOX_CHECK_EQ(names.size(), arrays.size()); + + auto nrows = (!arrays.empty()) ? (arrays[0]->length()) : 0; + arrow::FieldVector fields{}; + for (int i = 0; i < arrays.size(); i++) { + VELOX_CHECK_EQ(arrays[i]->length(), nrows); + fields.push_back( + std::make_shared(names[i], arrays[i]->type())); + } + + auto schema = arrow::schema(fields); + return arrow::RecordBatch::Make(schema, nrows, arrays); +} + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays) { + AFC_RETURN_OR_RAISE( + arrow::Table::FromRecordBatches({makeRecordBatch(names, arrays)})); +} + +std::string readFile(const std::string& path) { + std::ifstream file(path); + VELOX_CHECK( + file.is_open(), "Could not open file \"{}\": {}", path, strerror(errno)); + return { + std::istreambuf_iterator(file), std::istreambuf_iterator()}; +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h new file mode 100644 index 0000000000000..c3bdcc0df1af3 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::test { + +template +auto makeNumericArray(const std::vector& values) { + auto builder = arrow::NumericBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale); + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool = arrow::default_memory_pool()); + +std::shared_ptr makeStringArray( + const std::vector& values); + +std::shared_ptr makeBooleanArray(const std::vector& values); + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::string readFile(const std::string& path); + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt index 0547dcaeeafcb..9e1b6187eef4a 100644 --- a/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt @@ -20,6 +20,7 @@ add_test(presto_operators_test presto_operators_test) target_link_libraries( presto_operators_test + presto_connector presto_operators_plan_builder presto_operators presto_protocol diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 9d05cf2212194..4c18a3a074a6a 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -18,9 +18,9 @@ #include "folly/experimental/EventCount.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/TaskResource.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/tests/HttpServerWrapper.h" #include "presto_cpp/main/tests/MultableConfigs.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 5841728512238..e22e2d8ddbd6b 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -15,14 +15,14 @@ add_library(presto_type_converter OBJECT TypeParser.cpp) target_link_libraries(presto_type_converter velox_type_parser) add_library( - presto_types OBJECT - PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp VeloxPlanValidator.cpp - PrestoToVeloxSplit.cpp PrestoToVeloxConnector.cpp) + presto_types OBJECT PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp + VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp) add_dependencies(presto_types presto_operators presto_type_converter velox_type velox_type_fbhive) -target_link_libraries(presto_types presto_type_converter velox_type_fbhive - velox_hive_partition_function velox_tpch_gen velox_functions_json) +target_link_libraries( + presto_types presto_type_converter velox_type_fbhive + velox_hive_partition_function velox_tpch_gen velox_functions_json) set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index cdde0813df364..f2cd332682707 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -14,7 +14,7 @@ // clang-format off #include "presto_cpp/main/common/Configs.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include #include "velox/core/QueryCtx.h" diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp index 1d11be2e904fc..6ecda241f1478 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp @@ -12,7 +12,7 @@ * limitations under the License. */ #include "presto_cpp/main/types/PrestoToVeloxSplit.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "velox/exec/Exchange.h" using namespace facebook::velox; diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index 28f73aff40b80..4ef64f2412fb4 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -20,6 +20,7 @@ target_link_libraries( presto_velox_split_test GTest::gtest GTest::gtest_main + presto_connector presto_operators presto_protocol velox_dwio_common @@ -48,6 +49,7 @@ target_link_libraries( presto_expressions_test GTest::gtest GTest::gtest_main + presto_connector $ $ $ @@ -86,6 +88,7 @@ add_test( target_link_libraries( presto_to_velox_connector_test + presto_connector presto_protocol presto_operators presto_type_converter @@ -123,6 +126,7 @@ add_test( target_link_libraries( presto_to_velox_query_plan_test + presto_connector presto_operators presto_protocol presto_type_converter diff --git a/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp index 715780befa84a..3f9a402ba8809 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp @@ -14,11 +14,11 @@ #include #include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/operators/LocalPersistentShuffle.h" #include "presto_cpp/main/operators/PartitionAndSerialize.h" #include "presto_cpp/main/operators/ShuffleRead.h" #include "presto_cpp/main/operators/ShuffleWrite.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include "presto_cpp/main/types/tests/TestUtils.h" #include "velox/connectors/hive/TableHandle.h" diff --git a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp index 932f48a611f73..a88b235686498 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include #include "velox/common/base/tests/GTestUtils.h" diff --git a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp index 5522684b262d2..9bcbf9b4f0542 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp @@ -13,7 +13,7 @@ */ #include "presto_cpp/main/types/PrestoToVeloxSplit.h" #include -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" using namespace facebook::velox; diff --git a/presto-native-execution/presto_cpp/presto_protocol/Makefile b/presto-native-execution/presto_cpp/presto_protocol/Makefile index 3ee2b4e802b81..09b43df28b4f5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/Makefile +++ b/presto-native-execution/presto_cpp/presto_protocol/Makefile @@ -45,14 +45,23 @@ presto_protocol-cpp: presto_protocol-json chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-hpp.mustache >> connector/tpch/presto_protocol_tpch.h clang-format -style=file -i connector/tpch/presto_protocol_tpch.h connector/tpch/presto_protocol_tpch.cpp + # build arrow_flight connector related structs + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.cpp + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-cpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.cpp + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.h + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-hpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.h + clang-format -style=file -i connector/arrow_flight/presto_protocol_arrow_flight.h connector/arrow_flight/presto_protocol_arrow_flight.cpp + presto_protocol-json: ./java-to-struct-json.py --config core/presto_protocol_core.yml core/special/*.java core/special/*.inc -j | jq . > core/presto_protocol_core.json ./java-to-struct-json.py --config connector/hive/presto_protocol_hive.yml connector/hive/special/*.inc -j | jq . > connector/hive/presto_protocol_hive.json ./java-to-struct-json.py --config connector/iceberg/presto_protocol_iceberg.yml connector/iceberg/special/*.inc -j | jq . > connector/iceberg/presto_protocol_iceberg.json ./java-to-struct-json.py --config connector/tpch/presto_protocol_tpch.yml connector/tpch/special/*.inc -j | jq . > connector/tpch/presto_protocol_tpch.json + ./java-to-struct-json.py --config connector/arrow_flight/presto_protocol_arrow_flight.yml connector/arrow_flight/special/*.inc -j | jq . > connector/arrow_flight/presto_protocol_arrow_flight.json presto_protocol.proto: presto_protocol-json pystache presto_protocol-protobuf.mustache core/presto_protocol_core.json > core/presto_protocol_core.proto pystache presto_protocol-protobuf.mustache connector/hive/presto_protocol_hive.json > connector/hive/presto_protocol_hive.proto pystache presto_protocol-protobuf.mustache connector/iceberg/presto_protocol_iceberg.json > connector/iceberg/presto_protocol_iceberg.proto pystache presto_protocol-protobuf.mustache connector/tpch/presto_protocol_tpch.json > connector/tpch/presto_protocol_tpch.proto + pystache presto_protocol-protobuf.mustache connector/arrow_flight/presto_protocol_arrow_flight.json > connector/arrow_flight/presto_protocol_arrow_flight.proto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h new file mode 100644 index 0000000000000..95cda16115695 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" + +namespace facebook::presto::protocol::arrow_flight { +using ArrowConnectorProtocol = ConnectorProtocolTemplate< + ArrowTableHandle, + ArrowTableLayoutHandle, + ArrowColumnHandle, + NotImplemented, + NotImplemented, + ArrowSplit, + NotImplemented, + ArrowTransactionHandle, + NotImplemented>; +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache new file mode 100644 index 0000000000000..b6ecb68507285 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#cinc}} +{{&cinc}} +{{/cinc}} +{{^cinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + {{#super_class}} + {{&class_name}}::{{&class_name}}() noexcept { + _type = "{{json_key}}"; + } + {{/super_class}} + + void to_json(json& j, const {{&class_name}}& p) { + j = json::object(); + {{#super_class}} + j["@type"] = "{{&json_key}}"; + {{/super_class}} + {{#fields}} + to_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } + + void from_json(const json& j, {{&class_name}}& p) { + {{#super_class}} + p._type = j["@type"]; + {{/super_class}} + {{#fields}} + from_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + //Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + + // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays + static const std::pair<{{&class_name}}, json> + {{&class_name}}_enum_table[] = { // NOLINT: cert-err58-cpp + {{#elements}} + { {{&class_name}}::{{&element}}, "{{&element}}" }{{^_last}},{{/_last}} + {{/elements}} + }; + void to_json(json& j, const {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [e](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.first == e; + }); + j = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->second; + } + void from_json(const json& j, {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [&j](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.second == j; + }); + e = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->first; + } +} +{{/enum}} +{{#abstract}} +namespace facebook::presto::protocol::arrow_flight { + void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p) { + if ( p == nullptr ) { + return; + } + String type = p->_type; + + {{#subclasses}} + if ( type == "{{&key}}" ) { + j = *std::static_pointer_cast<{{&type}}>(p); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } + + void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error &e) { + throw ParseError(std::string(e.what()) + " {{&class_name}} {{&key}} {{&class_name}}"); + } + + {{#subclasses}} + if ( type == "{{&key}}" ) { + std::shared_ptr<{{&type}}> k = std::make_shared<{{&type}}>(); + j.get_to(*k); + p = std::static_pointer_cast<{{&class_name}}>(k); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } +} +{{/abstract}} +{{/cinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache new file mode 100644 index 0000000000000..be08bd9e491c2 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#hinc}} +{{&hinc}} +{{/hinc}} +{{^hinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + struct {{class_name}} {{#super_class}}: public {{super_class}}{{/super_class}}{ + {{#fields}} + {{#field_local}}{{#optional}}std::shared_ptr<{{/optional}}{{&field_text}}{{#optional}}>{{/optional}} {{&field_name}} = {};{{/field_local}} + {{/fields}} + + {{#super_class}} + {{class_name}}() noexcept; + {{/super_class}} + }; + void to_json(json& j, const {{class_name}}& p); + void from_json(const json& j, {{class_name}}& p); +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + enum class {{class_name}} { + {{#elements}} + {{&element}}{{^_last}},{{/_last}} + {{/elements}} + }; + extern void to_json(json& j, const {{class_name}}& e); + extern void from_json(const json& j, {{class_name}}& e); +} +{{/enum}} +{{/hinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp new file mode 100644 index 0000000000000..e5b5cf2f9ae3b --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp @@ -0,0 +1,215 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +// This file is generated DO NOT EDIT @generated + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowColumnHandle::ArrowColumnHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowColumnHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + to_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} + +void from_json(const json& j, ArrowColumnHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + from_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowSplit::ArrowSplit() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowSplit& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + to_json_key(j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + to_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", + "String", + "flightEndpointBytes"); +} + +void from_json(const json& j, ArrowSplit& p) { + p._type = j["@type"]; + from_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + from_json_key( + j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + from_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", + "String", + "flightEndpointBytes"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableHandle::ArrowTableHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + to_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} + +void from_json(const json& j, ArrowTableHandle& p) { + p._type = j["@type"]; + from_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + from_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "arrow-flight") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError(std::string(e.what()) + " ColumnHandle ColumnHandle"); + } + + if (type == "arrow-flight") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableLayoutHandle::ArrowTableLayoutHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableLayoutHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + to_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + to_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} + +void from_json(const json& j, ArrowTableLayoutHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + from_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + from_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h new file mode 100644 index 0000000000000..2a9cb81d00b47 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h @@ -0,0 +1,82 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +// This file is generated DO NOT EDIT @generated + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowColumnHandle : public ColumnHandle { + String columnName = {}; + Type columnType = {}; + + ArrowColumnHandle() noexcept; +}; +void to_json(json& j, const ArrowColumnHandle& p); +void from_json(const json& j, ArrowColumnHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowSplit : public ConnectorSplit { + String schemaName = {}; + String tableName = {}; + String flightEndpointBytes = {}; + + ArrowSplit() noexcept; +}; +void to_json(json& j, const ArrowSplit& p); +void from_json(const json& j, ArrowSplit& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableHandle : public ConnectorTableHandle { + String schema = {}; + String table = {}; + + ArrowTableHandle() noexcept; +}; +void to_json(json& j, const ArrowTableHandle& p); +void from_json(const json& j, ArrowTableHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableLayoutHandle : public ConnectorTableLayoutHandle { + ArrowTableHandle table = {}; + List columnHandles = {}; + TupleDomain> tupleDomain = {}; + + ArrowTableLayoutHandle() noexcept; +}; +void to_json(json& j, const ArrowTableLayoutHandle& p); +void from_json(const json& j, ArrowTableLayoutHandle& p); +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml new file mode 100644 index 0000000000000..f34f6068eb777 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +AbstractClasses: + ColumnHandle: + super: JsonEncodedSubclass + comparable: true + subclasses: + - { name: ArrowColumnHandle, key: arrow-flight } + + ConnectorTableHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableHandle, key: arrow-flight } + + ConnectorTableLayoutHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableLayoutHandle, key: arrow-flight } + + ConnectorSplit: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowSplit, key: arrow-flight } + +JavaClasses: + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc new file mode 100644 index 0000000000000..a93325f5b154a --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc new file mode 100644 index 0000000000000..dc573ca2e68cf --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index b0ecb4a22354a..e13e2eee8b750 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -1080,6 +1080,7 @@ void from_json(const json& j, std::shared_ptr& p) { */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 32feb37fd501e..3a50fc7eac205 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -54,6 +54,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -69,6 +70,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -96,6 +98,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -111,6 +114,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc index 8ec2a94e84bd9..1dfb17e4a908f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc @@ -13,6 +13,7 @@ */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp index c15084817a434..24f24f27f87a3 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp @@ -15,6 +15,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h index dd94975e3760d..c43ec92629f44 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h @@ -16,6 +16,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml index 2cda40445c76e..c4cf2130b1749 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml @@ -53,6 +53,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -68,6 +69,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -95,6 +97,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -110,6 +113,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass @@ -366,3 +370,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/connector/system/SystemTransactionHandle.java - presto-spi/src/main/java/com/facebook/presto/spi/function/AggregationFunctionMetadata.java - presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/scripts/setup-adapters.sh b/presto-native-execution/scripts/setup-adapters.sh index 6c36424ebf90c..3cb965fe71781 100755 --- a/presto-native-execution/scripts/setup-adapters.sh +++ b/presto-native-execution/scripts/setup-adapters.sh @@ -35,15 +35,75 @@ function install_prometheus_cpp { cmake_install -DBUILD_SHARED_LIBS=ON -DENABLE_PUSH=OFF -DENABLE_COMPRESSION=OFF } +function install_abseil { + # abseil-cpp + github_checkout abseil/abseil-cpp 20240116.2 --depth 1 + cmake_install \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + +function install_grpc { + # grpc + github_checkout grpc/grpc v1.48.1 --depth 1 + cmake_install \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_INSTALL=ON +} + +function install_arrow_flight { + ARROW_VERSION="${ARROW_VERSION:-15.0.0}" + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} + LINUX_DISTRIBUTION=$(. /etc/os-release && echo ${ID}) + if [[ "$LINUX_DISTRIBUTION" == "ubuntu" || "$LINUX_DISTRIBUTION" == "debian" ]]; then + SUDO="${SUDO:-"sudo --preserve-env"}" + ${SUDO} apt install -y libc-ares-dev + ${SUDO} ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | ${SUDO} tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ${SUDO} ldconfig + else + dnf -y install c-ares-devel + ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ldconfig + fi + else + # The installation script for the Arrow Flight connector currently works only on Linux distributions. + return 0 + fi + + install_abseil + install_grpc + + # NOTE: benchmarks are on due to a compilation error with v15.0.0, once updated that can be removed + # see https://github.com/apache/arrow/issues/41617 + wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow + cmake_install_dir arrow/cpp \ + -DARROW_FLIGHT=ON \ + -DARROW_BUILD_BENCHMARKS=ON \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +} + cd "${DEPENDENCY_DIR}" || exit install_jwt=0 install_prometheus_cpp=0 +install_arrow_flight=0 if [ "$#" -eq 0 ]; then # Install all adapters by default install_jwt=1 install_prometheus_cpp=1 + install_arrow_flight=1 fi while [[ $# -gt 0 ]]; do @@ -56,6 +116,10 @@ while [[ $# -gt 0 ]]; do install_prometheus_cpp=1; shift ;; + arrow_flight) + install_arrow_flight=1; + shift + ;; *) echo "ERROR: Unknown option $1! will be ignored!" shift @@ -72,6 +136,10 @@ if [ $install_prometheus_cpp -eq 1 ]; then install_prometheus_cpp fi +if [ $install_arrow_flight -eq 1 ]; then + install_arrow_flight +fi + _ret=$? if [ $_ret -eq 0 ] ; then echo "All deps for Presto adapters installed!"