From 1598782d4ffd3f7a961f379148f17f34e16caf2b Mon Sep 17 00:00:00 2001 From: Steve Lord <72518652+stevelorddremio@users.noreply.github.com> Date: Mon, 3 Jun 2024 12:57:42 -0700 Subject: [PATCH] GH-41262: [Java][FlightSQL] Implement stateless prepared statements (#41237) ### Rationale for this change Expand the number of implemented languages for stateless prepared statements to include Java. ### What changes are included in this PR? Update FlightSqlClient and include a stateless server implementation example with tests. ### Are these changes tested? Yes, tests are added to cover a stateless server implementation. ### Are there any user-facing changes? There is a modified FlightSqlClient that is required to enable use of stateless prepared statements. * GitHub Issue: #41262 Lead-authored-by: Steve Lord Co-authored-by: Mateusz Rzeszutek Signed-off-by: David Li --- .../arrow/flight/sql/FlightSqlClient.java | 27 +- .../DoPutPreparedStatementResultPOJO.java | 38 +++ .../flight/sql/example/FlightSqlExample.java | 60 +++-- .../example/FlightSqlStatelessExample.java | 238 ++++++++++++++++++ .../arrow/flight/sql/test/TestFlightSql.java | 63 +++-- .../sql/test/TestFlightSqlStateless.java | 99 ++++++++ 6 files changed, 474 insertions(+), 51 deletions(-) create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index 6fe31fae9216b..a94dc563cfbcc 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -78,6 +78,7 @@ import org.apache.arrow.flight.SetSessionOptionsResult; import org.apache.arrow.flight.SyncPutListener; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; import org.apache.arrow.flight.sql.util.TableRef; @@ -1048,15 +1049,35 @@ private Schema deserializeSchema(final ByteString bytes) { public FlightInfo execute(final CallOption... options) { checkOpen(); - final FlightDescriptor descriptor = FlightDescriptor + FlightDescriptor descriptor = FlightDescriptor .command(Any.pack(CommandPreparedStatementQuery.newBuilder() .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) .build()) .toByteArray()); if (parameterBindingRoot != null && parameterBindingRoot.getRowCount() > 0) { - try (final SyncPutListener listener = putParameters(descriptor, options)) { - listener.getResult(); + try (final SyncPutListener putListener = putParameters(descriptor, options)) { + if (getParameterSchema().getFields().size() > 0 && + parameterBindingRoot != null && + parameterBindingRoot.getRowCount() > 0) { + final PutResult read = putListener.read(); + if (read != null) { + try (final ArrowBuf metadata = read.getApplicationMetadata()) { + final FlightSql.DoPutPreparedStatementResult doPutPreparedStatementResult = + FlightSql.DoPutPreparedStatementResult.parseFrom(metadata.nioBuffer()); + descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementQuery.newBuilder() + .setPreparedStatementHandle( + doPutPreparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + } + } + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); } } diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java new file mode 100644 index 0000000000000..ace78862b014d --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.example; + +import java.io.Serializable; + +public class DoPutPreparedStatementResultPOJO implements Serializable { + private String query; + private byte[] parameters; + + public DoPutPreparedStatementResultPOJO(String query, byte[] parameters) { + this.query = query; + this.parameters = parameters.clone(); + } + + public String getQuery() { + return query; + } + + public byte[] getParameters() { + return parameters; + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 52c402efd6f0b..36362fd8681d3 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -156,21 +156,22 @@ * supports all current features of Flight SQL. */ public class FlightSqlExample implements FlightSqlProducer, AutoCloseable { - private static final String DATABASE_URI = "jdbc:derby:target/derbyDB"; private static final Logger LOGGER = getLogger(FlightSqlExample.class); - private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); + protected static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); + public static final String DB_NAME = "derbyDB"; + private final String databaseUri; // ARROW-15315: Use ExecutorService to simulate an async scenario private final ExecutorService executorService = Executors.newFixedThreadPool(10); private final Location location; - private final PoolingDataSource dataSource; - private final BufferAllocator rootAllocator = new RootAllocator(); + protected final PoolingDataSource dataSource; + protected final BufferAllocator rootAllocator = new RootAllocator(); private final Cache> preparedStatementLoadingCache; private final Cache> statementLoadingCache; private final SqlInfoBuilder sqlInfoBuilder; public static void main(String[] args) throws Exception { Location location = Location.forGrpcInsecure("localhost", 55555); - final FlightSqlExample example = new FlightSqlExample(location); + final FlightSqlExample example = new FlightSqlExample(location, DB_NAME); Location listenLocation = Location.forGrpcInsecure("0.0.0.0", 55555); try (final BufferAllocator allocator = new RootAllocator(); final FlightServer server = FlightServer.builder(allocator, listenLocation, example).build()) { @@ -179,13 +180,14 @@ public static void main(String[] args) throws Exception { } } - public FlightSqlExample(final Location location) { + public FlightSqlExample(final Location location, final String dbName) { // TODO Constructor should not be doing work. checkState( - removeDerbyDatabaseIfExists() && populateDerbyDatabase(), + removeDerbyDatabaseIfExists(dbName) && populateDerbyDatabase(dbName), "Failed to reset Derby database!"); + databaseUri = "jdbc:derby:target/" + dbName; final ConnectionFactory connectionFactory = - new DriverManagerConnectionFactory(DATABASE_URI, new Properties()); + new DriverManagerConnectionFactory(databaseUri, new Properties()); final PoolableConnectionFactory poolableConnectionFactory = new PoolableConnectionFactory(connectionFactory, null); final ObjectPool connectionPool = new GenericObjectPool<>(poolableConnectionFactory); @@ -248,9 +250,9 @@ public FlightSqlExample(final Location location) { } - private static boolean removeDerbyDatabaseIfExists() { + public static boolean removeDerbyDatabaseIfExists(final String dbName) { boolean wasSuccess; - final Path path = Paths.get("target" + File.separator + "derbyDB"); + final Path path = Paths.get("target" + File.separator + dbName); try (final Stream walk = Files.walk(path)) { /* @@ -262,7 +264,7 @@ private static boolean removeDerbyDatabaseIfExists() { * this not expected. */ wasSuccess = walk.sorted(Comparator.reverseOrder()).map(Path::toFile).map(File::delete) - .reduce(Boolean::logicalAnd).orElseThrow(IOException::new); + .reduce(Boolean::logicalAnd).orElseThrow(IOException::new); } catch (IOException e) { /* * The only acceptable scenario for an `IOException` to be thrown here is if @@ -277,9 +279,12 @@ private static boolean removeDerbyDatabaseIfExists() { return wasSuccess; } - private static boolean populateDerbyDatabase() { - try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/derbyDB;create=true"); + private static boolean populateDerbyDatabase(final String dbName) { + try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/" + dbName + ";create=true"); Statement statement = connection.createStatement()) { + + dropTable(statement, "intTable"); + dropTable(statement, "foreignTable"); statement.execute("CREATE TABLE foreignTable (" + "id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " + "foreignName varchar(100), " + @@ -302,6 +307,18 @@ private static boolean populateDerbyDatabase() { return true; } + private static void dropTable(final Statement statement, final String tableName) throws SQLException { + try { + statement.execute("DROP TABLE " + tableName); + } catch (SQLException e) { + // sql error code for "object does not exist"; which is fine, we're trying to delete the table + // see https://db.apache.org/derby/docs/10.17/ref/rrefexcept71493.html + if (!"42Y55".equals(e.getSQLState())) { + throw e; + } + } + } + private static ArrowType getArrowTypeFromJdbcType(final int jdbcDataType, final int precision, final int scale) { try { return JdbcToArrowUtils.getArrowTypeFromJdbcType(new JdbcFieldInfo(jdbcDataType, precision, scale), @@ -778,7 +795,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r // Running on another thread Future unused = executorService.submit(() -> { try { - final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8)); + final ByteString preparedStatementHandle = copyFrom(request.getQuery().getBytes(StandardCharsets.UTF_8)); // Ownership of the connection will be passed to the context. Do NOT close! final Connection connection = dataSource.getConnection(); final PreparedStatement preparedStatement = connection.prepareStatement(request.getQuery(), @@ -882,7 +899,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate while (binder.next()) { preparedStatement.addBatch(); } - int[] recordCounts = preparedStatement.executeBatch(); + final int[] recordCounts = preparedStatement.executeBatch(); recordCount = Arrays.stream(recordCounts).sum(); } @@ -928,6 +945,7 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co .toRuntimeException()); return; } + ackStream.onCompleted(); }; } @@ -1035,7 +1053,7 @@ public void getStreamTables(final CommandGetTables command, final CallContext co final String[] tableTypes = protocolSize == 0 ? null : protocolStringList.toArray(new String[protocolSize]); - try (final Connection connection = DriverManager.getConnection(DATABASE_URI); + try (final Connection connection = DriverManager.getConnection(databaseUri); final VectorSchemaRoot vectorSchemaRoot = getTablesRoot( connection.getMetaData(), rootAllocator, @@ -1086,7 +1104,7 @@ public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final Call final String schema = command.hasDbSchema() ? command.getDbSchema() : null; final String table = command.getTable(); - try (Connection connection = DriverManager.getConnection(DATABASE_URI)) { + try (Connection connection = DriverManager.getConnection(databaseUri)) { final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table); final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator); @@ -1140,7 +1158,7 @@ public void getStreamExportedKeys(final CommandGetExportedKeys command, final Ca String schema = command.hasDbSchema() ? command.getDbSchema() : null; String table = command.getTable(); - try (Connection connection = DriverManager.getConnection(DATABASE_URI); + try (Connection connection = DriverManager.getConnection(databaseUri); ResultSet keys = connection.getMetaData().getExportedKeys(catalog, schema, table); VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { listener.start(vectorSchemaRoot); @@ -1165,7 +1183,7 @@ public void getStreamImportedKeys(final CommandGetImportedKeys command, final Ca String schema = command.hasDbSchema() ? command.getDbSchema() : null; String table = command.getTable(); - try (Connection connection = DriverManager.getConnection(DATABASE_URI); + try (Connection connection = DriverManager.getConnection(databaseUri); ResultSet keys = connection.getMetaData().getImportedKeys(catalog, schema, table); VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { listener.start(vectorSchemaRoot); @@ -1193,7 +1211,7 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex final String pkTable = command.getPkTable(); final String fkTable = command.getFkTable(); - try (Connection connection = DriverManager.getConnection(DATABASE_URI); + try (Connection connection = DriverManager.getConnection(databaseUri); ResultSet keys = connection.getMetaData() .getCrossReference(pkCatalog, pkSchema, pkTable, fkCatalog, fkSchema, fkTable); VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { @@ -1280,7 +1298,7 @@ public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, } } - private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, + protected FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, final Schema schema) { final Ticket ticket = new Ticket(pack(request).toByteArray()); // TODO Support multiple endpoints. diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java new file mode 100644 index 0000000000000..c79c09c0967dc --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.example; + +import static java.lang.String.format; +import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema; +import static org.apache.arrow.flight.sql.impl.FlightSql.*; +import static org.slf4j.LoggerFactory.getLogger; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.StreamCorruptedException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; + +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.ArrowFileReader; +import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.ipc.SeekableReadChannel; +import org.apache.arrow.vector.ipc.message.ArrowBlock; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; +import org.slf4j.Logger; + +import com.google.protobuf.ByteString; + +/** + * Example {@link FlightSqlProducer} implementation showing an Apache Derby backed Flight SQL server that generally + * supports all current features of Flight SQL. + */ +public class FlightSqlStatelessExample extends FlightSqlExample { + private static final Logger LOGGER = getLogger(FlightSqlStatelessExample.class); + public static final String DB_NAME = "derbyStatelessDB"; + + + public FlightSqlStatelessExample(final Location location, final String dbName) { + super(location, dbName); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + + return () -> { + final String query = new String(command.getPreparedStatementHandle().toStringUtf8()); + try (Connection connection = dataSource.getConnection(); + PreparedStatement preparedStatement = createPreparedStatement(connection, query)) { + while (flightStream.next()) { + final VectorSchemaRoot root = flightStream.getRoot(); + final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); + while (binder.next()) { + // Do not execute() - will be done in a getStream call + } + + final ByteArrayOutputStream parametersStream = new ByteArrayOutputStream(); + try (ArrowFileWriter writer = new ArrowFileWriter(root, null, Channels.newChannel(parametersStream)) + ) { + writer.start(); + writer.writeBatch(); + } + + if (parametersStream.size() > 0) { + final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = + new DoPutPreparedStatementResultPOJO(query, parametersStream.toByteArray()); + + final byte[] doPutPreparedStatementResultPOJOArr = serializePOJO(doPutPreparedStatementResultPOJO); + final DoPutPreparedStatementResult doPutPreparedStatementResult = + DoPutPreparedStatementResult.newBuilder() + .setPreparedStatementHandle( + ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr))) + .build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) { + buffer.writeBytes(doPutPreparedStatementResult.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + } + } + } + + } catch (SQLException | IOException e) { + ackStream.onError(CallStatus.INTERNAL + .withDescription("Failed to bind parameters: " + e.getMessage()) + .withCause(e) + .toRuntimeException()); + return; + } + + ackStream.onCompleted(); + }; + } + + @Override + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, + final ServerStreamListener listener) { + final byte[] handle = command.getPreparedStatementHandle().toByteArray(); + try { + // Case where there are parameters + try { + final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = + deserializePOJO(handle); + final String query = doPutPreparedStatementResultPOJO.getQuery(); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = createPreparedStatement(connection, query); + ArrowFileReader reader = new ArrowFileReader(new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel( + doPutPreparedStatementResultPOJO.getParameters())), rootAllocator)) { + + for (ArrowBlock arrowBlock : reader.getRecordBlocks()) { + reader.loadRecordBatch(arrowBlock); + VectorSchemaRoot vectorSchemaRootRecover = reader.getVectorSchemaRoot(); + JdbcParameterBinder binder = JdbcParameterBinder.builder(statement, vectorSchemaRootRecover) + .bindAll().build(); + + while (binder.next()) { + executeQuery(statement, listener); + } + } + } + } catch (StreamCorruptedException e) { + // Case where there are no parameters + final String query = new String(command.getPreparedStatementHandle().toStringUtf8()); + try (Connection connection = dataSource.getConnection(); + PreparedStatement preparedStatement = createPreparedStatement(connection, query)) { + executeQuery(preparedStatement, listener); + } + } + } catch (final SQLException | IOException | ClassNotFoundException e) { + LOGGER.error(format("Failed to getStreamPreparedStatement: <%s>.", e.getMessage()), e); + listener.error(CallStatus.INTERNAL.withDescription("Failed to prepare statement: " + e).toRuntimeException()); + } finally { + listener.completed(); + } + } + + private void executeQuery(PreparedStatement statement, + final ServerStreamListener listener) throws IOException, SQLException { + try (final ResultSet resultSet = statement.executeQuery()) { + final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); + try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + + final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorSchemaRoot batch = iterator.next(); + if (batch.getRowCount() == 0) { + break; + } + final VectorUnloader unloader = new VectorUnloader(batch); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); + } + listener.putNext(); + } + } + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, + final CallContext context, + final FlightDescriptor descriptor) { + final byte[] handle = command.getPreparedStatementHandle().toByteArray(); + try { + String query; + try { + query = deserializePOJO(handle).getQuery(); + } catch (StreamCorruptedException e) { + query = new String(command.getPreparedStatementHandle().toStringUtf8()); + } + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = createPreparedStatement(connection, query)) { + ResultSetMetaData metaData = statement.getMetaData(); + return getFlightInfoForSchema(command, descriptor, + jdbcToArrowSchema(metaData, DEFAULT_CALENDAR)); + } + } catch (final SQLException | IOException | ClassNotFoundException e) { + LOGGER.error(format("There was a problem executing the prepared statement: <%s>.", e.getMessage()), e); + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } + + private DoPutPreparedStatementResultPOJO deserializePOJO(byte[] handle) throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(handle); + ObjectInputStream ois = new ObjectInputStream(bis)) { + return (DoPutPreparedStatementResultPOJO) ois.readObject(); + } + } + + private byte[] serializePOJO(DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject(doPutPreparedStatementResultPOJO); + return bos.toByteArray(); + } + } + + private PreparedStatement createPreparedStatement(Connection connection, String query) throws SQLException { + return connection.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index a39736e939f0b..ffffdd62ac950 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -87,63 +87,72 @@ public class TestFlightSql { Field.nullable("FOREIGNID", MinorType.INT.getType()))); private static final List> EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY = ImmutableList.of( asList("1", "one", "1", "1"), asList("2", "zero", "0", "1"), asList("3", "negative one", "-1", "1")); - private static final List> EXPECTED_RESULTS_FOR_PARAMETER_BINDING = ImmutableList.of( + protected static final List> EXPECTED_RESULTS_FOR_PARAMETER_BINDING = ImmutableList.of( asList("1", "one", "1", "1")); private static final Map GET_SQL_INFO_EXPECTED_RESULTS_MAP = new LinkedHashMap<>(); - private static final String LOCALHOST = "localhost"; - private static BufferAllocator allocator; - private static FlightServer server; - private static FlightSqlClient sqlClient; + protected static final String LOCALHOST = "localhost"; + protected static BufferAllocator allocator; + protected static FlightServer server; + protected static FlightSqlClient sqlClient; @BeforeAll public static void setUp() throws Exception { + setUpClientServer(); + setUpExpectedResultsMap(); + } + + private static void setUpClientServer() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); - server = FlightServer.builder(allocator, serverLocation, new FlightSqlExample(serverLocation)) - .build() - .start(); + server = FlightServer.builder(allocator, serverLocation, + new FlightSqlExample(serverLocation, FlightSqlExample.DB_NAME)) + .build() + .start(); final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + protected static void setUpExpectedResultsMap() { GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), "10.14.2.0 - (1828579)"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), "10.14.2.0 - (1828579)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_NULL_ORDERING_VALUE), - Integer.toString(FlightSql.SqlNullOrdering.SQL_NULLS_SORTED_AT_END_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_NULL_ORDERING_VALUE), + Integer.toString(FlightSql.SqlNullOrdering.SQL_NULLS_SORTED_AT_END_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_SCHEMA_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_SCHEMA_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_TABLE_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_TABLE_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_CASE_VALUE), - Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE), "\""); + .put(Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE), "\""); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), - Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE), "42"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE), "42"); } @AfterAll public static void tearDown() throws Exception { close(sqlClient, server, allocator); + FlightSqlExample.removeDerbyDatabaseIfExists(FlightSqlExample.DB_NAME); } private static List> getNonConformingResultsForGetSqlInfo(final List> results) { diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java new file mode 100644 index 0000000000000..09c7b2ef87f45 --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.test; + +import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults; +import static org.apache.arrow.util.AutoCloseables.close; +import static org.hamcrest.CoreMatchers.*; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; +import org.apache.arrow.flight.sql.example.FlightSqlStatelessExample; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * Test direct usage of Flight SQL workflows. + */ +public class TestFlightSqlStateless extends TestFlightSql { + + @BeforeAll + public static void setUp() throws Exception { + setUpClientServer(); + setUpExpectedResultsMap(); + } + + @AfterAll + public static void tearDown() throws Exception { + close(sqlClient, server, allocator); + FlightSqlStatelessExample.removeDerbyDatabaseIfExists(FlightSqlStatelessExample.DB_NAME); + } + + private static void setUpClientServer() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + + final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); + server = FlightServer.builder(allocator, serverLocation, + new FlightSqlStatelessExample(serverLocation, FlightSqlStatelessExample.DB_NAME)) + .build() + .start(); + + final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); + sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + + @Override + @Test + public void testSimplePreparedStatementResultsWithParameterBinding() throws Exception { + try (PreparedStatement prepare = sqlClient.prepare("SELECT * FROM intTable WHERE id = ?")) { + final Schema parameterSchema = prepare.getParameterSchema(); + try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { + insertRoot.allocateNew(); + + final IntVector valueVector = (IntVector) insertRoot.getVector(0); + valueVector.setSafe(0, 1); + insertRoot.setRowCount(1); + + prepare.setParameters(insertRoot); + final FlightInfo flightInfo = prepare.execute(); + + for (FlightEndpoint endpoint: flightInfo.getEndpoints()) { + try (FlightStream stream = sqlClient.getStream(endpoint.getTicket())) { + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), + () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)) + ); + } + } + } + } + } +}