diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 10e0c925..926a2cd4 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -1,5 +1,6 @@ #include +#include #include #include #ifndef _WIN32 @@ -420,6 +421,26 @@ static VALUE allocate(VALUE klass) { return obj; } +static void rb_mysql_client_set_active_fiber(VALUE self, bool closing) { + VALUE fiber_current = rb_fiber_current(); + GET_CLIENT(self); + + // see if this connection is still waiting on a result from a previous query + if (NIL_P(wrapper->active_fiber) || (closing && !rb_fiber_alive_p(wrapper->active_fiber))) { + // mark this connection active + wrapper->active_fiber = fiber_current; + } else if (wrapper->active_fiber == fiber_current) { + if (!closing) { + rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result"); + } + } else { + VALUE inspect = rb_inspect(wrapper->active_fiber); + const char *thr = StringValueCStr(inspect); + + rb_raise(cMysql2Error, "This connection is in use by: %s", thr); + } +} + /* call-seq: * Mysql2::Client.escape(string) * @@ -571,11 +592,14 @@ static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VA */ static VALUE rb_mysql_client_close(VALUE self) { GET_CLIENT(self); + rb_mysql_client_set_active_fiber(self, true); if (wrapper->client) { rb_thread_call_without_gvl(nogvl_close, wrapper, RUBY_UBF_IO, 0); } + wrapper->active_fiber = Qnil; + return Qnil; } @@ -798,24 +822,6 @@ static VALUE disconnect_and_mark_inactive(VALUE self) { return Qnil; } -static void rb_mysql_client_set_active_fiber(VALUE self) { - VALUE fiber_current = rb_fiber_current(); - GET_CLIENT(self); - - // see if this connection is still waiting on a result from a previous query - if (NIL_P(wrapper->active_fiber)) { - // mark this connection active - wrapper->active_fiber = fiber_current; - } else if (wrapper->active_fiber == fiber_current) { - rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result"); - } else { - VALUE inspect = rb_inspect(wrapper->active_fiber); - const char *thr = StringValueCStr(inspect); - - rb_raise(cMysql2Error, "This connection is in use by: %s", thr); - } -} - /* call-seq: * client.abandon_results! * @@ -873,7 +879,7 @@ static VALUE rb_mysql_query(VALUE self, VALUE sql, VALUE current) { args.sql_len = RSTRING_LEN(args.sql); args.wrapper = wrapper; - rb_mysql_client_set_active_fiber(self); + rb_mysql_client_set_active_fiber(self, false); #ifndef _WIN32 rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0); @@ -1233,12 +1239,16 @@ static void *nogvl_ping(void *ptr) { */ static VALUE rb_mysql_client_ping(VALUE self) { GET_CLIENT(self); + rb_mysql_client_set_active_fiber(self, false); + VALUE result = Qnil; if (!CONNECTED(wrapper)) { - return Qfalse; + result = Qfalse; } else { - return (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0); + result = (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0); } + wrapper->active_fiber = Qnil; + return result; } /* call-seq: diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index 2c4de37d..c4bc814f 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -608,7 +608,7 @@ def run_gc it "should detect closed connection on query read error" do connection_id = @client.thread_id - Thread.new do + new_thread do sleep(0.1) Mysql2::Client.new(DatabaseCredentials['root']).tap do |supervisor| supervisor.query("KILL #{connection_id}") @@ -633,14 +633,32 @@ def run_gc end.to raise_error(Mysql2::Error) end + it "should prevent using a connection held by a dead thread, but not closing it" do + thr = new_thread do + @client.query("SELECT SLEEP(2)") + end + thr.join(0.5) + thr.kill + thr.join + + expect { @client.query("SELECT 4") }.to raise_error(Mysql2::Error) + @client.close + end + it "should describe the thread holding the active query" do - thr = Thread.new do + out_queue = Queue.new + in_queue = Queue.new + + thr = new_thread do @client.query("SELECT 1", async: true) - Fiber.current + out_queue << Fiber.current + in_queue.pop end - fiber = thr.value + fiber = out_queue.pop expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, Regexp.new(Regexp.escape(fiber.inspect))) + in_queue.close + thr.join end it "should timeout if we wait longer than :read_timeout" do @@ -742,7 +760,7 @@ def run_gc # Note that each thread opens its own database connection start = clock_time threads = Array.new(5) do - Thread.new do + new_thread do new_client do |client| client.query("SELECT SLEEP(#{sleep_time})") end @@ -773,6 +791,12 @@ def run_gc result = @client.async_result expect(result).to be_an_instance_of(Mysql2::Result) end + + it "can close a connection with on the fly async query" do + expect(@client.query("SELECT sleep(0.5)", async: true)).to eql(nil) + @client.close + expect(@client.async_result).to be nil + end end context "Multiple results sets" do @@ -875,13 +899,13 @@ def run_gc it "should not overflow the thread stack" do expect do - Thread.new { Mysql2::Client.escape("'" * 256 * 1024) }.join + new_thread { Mysql2::Client.escape("'" * 256 * 1024) }.join end.not_to raise_error end it "should not overflow the process stack" do expect do - Thread.new { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join + new_thread { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join end.not_to raise_error end @@ -912,13 +936,13 @@ def run_gc it "should not overflow the thread stack" do expect do - Thread.new { @client.escape("'" * 256 * 1024) }.join + new_thread { @client.escape("'" * 256 * 1024) }.join end.not_to raise_error end it "should not overflow the process stack" do expect do - Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join + new_thread { @client.escape("'" * 1024 * 1024 * 4) }.join end.not_to raise_error end @@ -1011,7 +1035,13 @@ def run_gc it "should raise a Mysql2::Error::ConnectionError exception upon connection failure due to invalid credentials" do expect do - new_client(host: 'localhost', username: 'asdfasdf8d2h', password: 'asdfasdfw42') + begin + new_client(host: 'localhost', username: 'asdfasdf8d2h', password: 'asdfasdfw42') + rescue Mysql2::Error => e + raise unless e.message.include?('mysql_native_password') + + skip("Native password is not supported") + end end.to raise_error(Mysql2::Error::ConnectionError) expect do @@ -1273,4 +1303,30 @@ def connect(*args); end expect(client.inspect).not_to include("pass") expect(client.inspect).not_to include("secretsecret") end + + it "should not allow concurrent use of #ping" do + @client.ping + thread = new_thread { @client.query("SELECT SLEEP(1)") } + thread.join(0.1) + 10.times do + expect do + @client.ping + end.to raise_error(Mysql2::Error, /This connection is in use by/) + end + expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }]) + expect(@client.ping).to eq(true) + end + + it "should not allow concurrent use of #close" do + @client.ping + thread = new_thread { @client.query("SELECT SLEEP(1)") } + thread.join(0.1) + 10.times do + expect do + @client.close + end.to raise_error(Mysql2::Error, /This connection is in use by/) + end + expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }]) + expect(@client.close).to be_nil + end end diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index f7ac9a8f..01881749 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -51,6 +51,11 @@ def new_client(option_overrides = {}) end end + def new_thread(&block) + @threads << (thr = Thread.new(&block)) + thr + end + def num_classes # rubocop:disable Lint/UnifiedInteger 0.instance_of?(Integer) ? [Integer] : [Fixnum, Bignum] @@ -174,10 +179,15 @@ def ssl_cert_host end config.before(:example) do + @threads = [] @client = new_client end config.after(:example) do + @threads.each do |thr| + thr.kill + thr.join + end @clients.each(&:close) end end