From f94ad6c0174139439c4a14c3044d78c1362d615e Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Mon, 5 Dec 2016 15:27:48 -0600 Subject: [PATCH 1/2] Fixed issue where disconnect was not reconnecting and added tests --- .../github/pgasync/ConnectionPoolBuilder.java | 10 +- .../pgasync/impl/ConnectionValidator.java | 46 ++++--- .../pgasync/impl/ConnectionValidatorTest.java | 128 ++++++++++++++++++ 3 files changed, 165 insertions(+), 19 deletions(-) create mode 100644 src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java diff --git a/src/main/java/com/github/pgasync/ConnectionPoolBuilder.java b/src/main/java/com/github/pgasync/ConnectionPoolBuilder.java index bf42be2..b490abc 100644 --- a/src/main/java/com/github/pgasync/ConnectionPoolBuilder.java +++ b/src/main/java/com/github/pgasync/ConnectionPoolBuilder.java @@ -95,6 +95,11 @@ public ConnectionPoolBuilder validationQuery(String validationQuery) { return this; } + public ConnectionPoolBuilder validateSocket(boolean validateSocket) { + properties.validateSocket = validateSocket; + return this; + } + /** * Configuration for a pool. */ @@ -111,6 +116,7 @@ public static class PoolProperties { boolean useSsl; boolean usePipelining; String validationQuery; + boolean validateSocket; public String getHostname() { return hostname; @@ -140,9 +146,7 @@ public DataConverter getDataConverter() { return dataConverter != null ? dataConverter : new DataConverter(converters); } public Func1> getValidator() { - return validationQuery == null || validationQuery.trim().isEmpty() - ? Observable::just - : new ConnectionValidator(validationQuery)::validate; + return new ConnectionValidator(validationQuery, validateSocket)::validate; } } } \ No newline at end of file diff --git a/src/main/java/com/github/pgasync/impl/ConnectionValidator.java b/src/main/java/com/github/pgasync/impl/ConnectionValidator.java index 05e7e2f..3a4f864 100644 --- a/src/main/java/com/github/pgasync/impl/ConnectionValidator.java +++ b/src/main/java/com/github/pgasync/impl/ConnectionValidator.java @@ -26,26 +26,40 @@ public class ConnectionValidator { final String validationQuery; + final boolean validateSocket; - public ConnectionValidator(String validationQuery) { - this.validationQuery = validationQuery; + public ConnectionValidator(String validationQuery, boolean validateSocket) { + // Trimmed as empty means no query for backwards compatibility + this.validationQuery = validationQuery == null || validationQuery.trim().isEmpty() ? null : validationQuery; + this.validateSocket = validateSocket; } public Observable validate(Connection connection) { - return connection.queryRows(validationQuery) - .lift(subscriber -> new Subscriber() { - @Override - public void onError(Throwable e) { - subscriber.onError(e); - } - @Override - public void onCompleted() { - subscriber.onNext(connection); - subscriber.onCompleted(); - } - @Override - public void onNext(Row row) { } - }); + Observable ret = Observable.just(connection); + if (validationQuery != null) { + ret = ret.flatMap(conn -> connection.queryRows(validationQuery) + .lift(subscriber -> new Subscriber() { + @Override + public void onError(Throwable e) { + subscriber.onError(e); + } + @Override + public void onCompleted() { + subscriber.onNext(connection); + subscriber.onCompleted(); + } + @Override + public void onNext(Row row) { } + })); + } + if (validateSocket) { + ret = ret.doOnNext(conn -> { + if (conn instanceof PgConnection && !((PgConnection) conn).isConnected()) { + throw new IllegalStateException("Channel is closed"); + } + }); + } + return ret; } } diff --git a/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java b/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java new file mode 100644 index 0000000..3bd38b7 --- /dev/null +++ b/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java @@ -0,0 +1,128 @@ +package com.github.pgasync.impl; + +import com.github.pgasync.Connection; +import com.github.pgasync.ResultSet; +import com.github.pgasync.SqlException; +import org.junit.Test; +import rx.Observable; + +import java.util.function.Consumer; + +import static org.junit.Assert.*; + +public class ConnectionValidatorTest { + + @Test + public void shouldBeTheSamePidOnSuccessiveCalls() { + withDbr(null, true, dbr -> { + // Simple sanity check for our PID assumptions + assertEquals(selectPid(dbr).toBlocking().single().intValue(), + selectPid(dbr).toBlocking().single().intValue()); + }); + } + + @Test + public void shouldBeSamePidWhenValidationQuerySucceeds() { + withDbr("SELECT 1", false, dbr -> { + // Just compare PIDs + assertEquals(selectPid(dbr).toBlocking().single().intValue(), + selectPid(dbr).toBlocking().single().intValue()); + }); + } + + @Test + public void shouldBeDifferentPidWhenValidationQueryFails() throws Exception { + String errSql = + "DO language plpgsql $$\n" + + " BEGIN\n" + + " IF (SELECT COUNT(1) FROM VSTATE) = 1 THEN\n" + + " RAISE 'ERR';\n" + + " END IF;\n" + + " EXCEPTION\n" + + " WHEN undefined_table THEN\n" + + " END\n" + + "$$;"; + withDbr(errSql, false, dbr -> { + // Add the VSTATE table + dbr.query("DROP TABLE IF EXISTS VSTATE; CREATE TABLE VSTATE (ID VARCHAR(255) PRIMARY KEY)"); + + try { + // Grab the pid + int pid = selectPid(dbr).toBlocking().single(); + + // Break it + runFromOutside(dbr, "INSERT INTO VSTATE VALUES('A')"); + + // Make sure it is broken + try { + selectPid(dbr).toBlocking().single(); + fail("Should be broken"); + } catch (SqlException e) { } + + // Fix it, and go ahead and expect the same PID + runFromOutside(dbr, "TRUNCATE TABLE VSTATE"); + assertEquals(pid, selectPid(dbr).toBlocking().single().intValue()); + } finally { + runFromOutside(dbr, "DROP TABLE IF EXISTS VSTATE"); + } + }); + } + + @Test + public void shouldErrorWhenNotValidatingSocket() { + withDbr(null, false, dbr -> { + // Simple check, kill from outside, confirm failure + assertNotNull(selectPid(dbr).toBlocking().single()); + killConnectionFromOutside(dbr); + try { + selectPid(dbr).toBlocking().single(); + fail("Should not succeed after killing connection"); + } catch (IllegalStateException e) { } + }); + } + + @Test + public void shouldNotErrorWhenValidatingSocket() { + withDbr(null, true, dbr -> { + // Grab pid, kill from outside, confirm different pid + int pid = selectPid(dbr).toBlocking().single(); + killConnectionFromOutside(dbr); + assertNotEquals(pid, selectPid(dbr).toBlocking().single().intValue()); + }); + } + + private static Observable selectPid(DatabaseRule dbr) { + return dbr.db().queryRows("SELECT pg_backend_pid()").map(r -> r.getInt(0)); + } + + private static void killConnectionFromOutside(DatabaseRule dbr) { + ResultSet rs = runFromOutside(dbr, "SELECT pg_terminate_backend(pid) FROM pg_stat_activity " + + "WHERE pid <> pg_backend_pid() AND datname = '" + ((PgConnectionPool) dbr.pool).database + "'"); + assertEquals(1, rs.size()); + // Unfortunately, it appears we have to wait a tiny bit after + // killing the connection for netty to know + try { Thread.sleep(300); } catch (Exception e) { } + } + + private static ResultSet runFromOutside(DatabaseRule dbr, String query) { + PgConnectionPool pool = (PgConnectionPool) dbr.pool; + try(Connection conn = new PgConnection(pool.openStream(pool.address), pool.dataConverter). + connect(pool.username, pool.password, pool.database).toBlocking().single()) { + return conn.querySet(query).toBlocking().single(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static void withDbr(String validationQuery, boolean validateSocket, Consumer fn) { + DatabaseRule rule = new DatabaseRule(); + rule.builder.validationQuery(validationQuery); + rule.builder.validateSocket(validateSocket); + rule.before(); + try { + fn.accept(rule); + } finally { + rule.after(); + } + } +} From 32b1827052be9bb5622cb3b8f596049f37995fce Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Tue, 6 Dec 2016 12:00:19 -0600 Subject: [PATCH 2/2] Change method name to be accurate --- .../java/com/github/pgasync/impl/ConnectionValidatorTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java b/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java index 3bd38b7..fdd418e 100644 --- a/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java +++ b/src/test/java/com/github/pgasync/impl/ConnectionValidatorTest.java @@ -31,7 +31,7 @@ public void shouldBeSamePidWhenValidationQuerySucceeds() { } @Test - public void shouldBeDifferentPidWhenValidationQueryFails() throws Exception { + public void shouldFailValidationQueryFailsAndReconnectAfterSuccess() throws Exception { String errSql = "DO language plpgsql $$\n" + " BEGIN\n" +