Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions ext/mysql2/client.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <mysql2_ext.h>

#include <stdbool.h>
#include <time.h>
#include <errno.h>
#ifndef _WIN32
Expand Down Expand Up @@ -420,6 +421,26 @@ static VALUE allocate(VALUE klass) {
return obj;
}

static void rb_mysql_client_set_active_fiber(VALUE self, bool closing) {
VALUE fiber_current = rb_fiber_current();
GET_CLIENT(self);

// see if this connection is still waiting on a result from a previous query
if (NIL_P(wrapper->active_fiber) || (closing && !rb_fiber_alive_p(wrapper->active_fiber))) {
// mark this connection active
wrapper->active_fiber = fiber_current;
} else if (wrapper->active_fiber == fiber_current) {
if (!closing) {
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
}
} else {
VALUE inspect = rb_inspect(wrapper->active_fiber);
const char *thr = StringValueCStr(inspect);

rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
}
}

/* call-seq:
* Mysql2::Client.escape(string)
*
Expand Down Expand Up @@ -571,11 +592,14 @@ static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VA
*/
static VALUE rb_mysql_client_close(VALUE self) {
GET_CLIENT(self);
rb_mysql_client_set_active_fiber(self, true);

if (wrapper->client) {
rb_thread_call_without_gvl(nogvl_close, wrapper, RUBY_UBF_IO, 0);
}

wrapper->active_fiber = Qnil;

return Qnil;
}

Expand Down Expand Up @@ -798,24 +822,6 @@ static VALUE disconnect_and_mark_inactive(VALUE self) {
return Qnil;
}

static void rb_mysql_client_set_active_fiber(VALUE self) {
VALUE fiber_current = rb_fiber_current();
GET_CLIENT(self);

// see if this connection is still waiting on a result from a previous query
if (NIL_P(wrapper->active_fiber)) {
// mark this connection active
wrapper->active_fiber = fiber_current;
} else if (wrapper->active_fiber == fiber_current) {
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
} else {
VALUE inspect = rb_inspect(wrapper->active_fiber);
const char *thr = StringValueCStr(inspect);

rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
}
}

/* call-seq:
* client.abandon_results!
*
Expand Down Expand Up @@ -873,7 +879,7 @@ static VALUE rb_mysql_query(VALUE self, VALUE sql, VALUE current) {
args.sql_len = RSTRING_LEN(args.sql);
args.wrapper = wrapper;

rb_mysql_client_set_active_fiber(self);
rb_mysql_client_set_active_fiber(self, false);

#ifndef _WIN32
rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0);
Expand Down Expand Up @@ -1233,12 +1239,16 @@ static void *nogvl_ping(void *ptr) {
*/
static VALUE rb_mysql_client_ping(VALUE self) {
GET_CLIENT(self);
rb_mysql_client_set_active_fiber(self, false);

VALUE result = Qnil;
if (!CONNECTED(wrapper)) {
return Qfalse;
result = Qfalse;
} else {
return (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
result = (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
}
wrapper->active_fiber = Qnil;
return result;
}

/* call-seq:
Expand Down
76 changes: 66 additions & 10 deletions spec/mysql2/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def run_gc

it "should detect closed connection on query read error" do
connection_id = @client.thread_id
Thread.new do
new_thread do
sleep(0.1)
Mysql2::Client.new(DatabaseCredentials['root']).tap do |supervisor|
supervisor.query("KILL #{connection_id}")
Expand All @@ -633,14 +633,32 @@ def run_gc
end.to raise_error(Mysql2::Error)
end

it "should prevent using a connection held by a dead thread, but not closing it" do
thr = new_thread do
@client.query("SELECT SLEEP(2)")
end
thr.join(0.5)
thr.kill
thr.join

expect { @client.query("SELECT 4") }.to raise_error(Mysql2::Error)
@client.close
end

it "should describe the thread holding the active query" do
thr = Thread.new do
out_queue = Queue.new
in_queue = Queue.new

thr = new_thread do
@client.query("SELECT 1", async: true)
Fiber.current
out_queue << Fiber.current
in_queue.pop
end

fiber = thr.value
fiber = out_queue.pop
expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, Regexp.new(Regexp.escape(fiber.inspect)))
in_queue.close
thr.join
end

it "should timeout if we wait longer than :read_timeout" do
Expand Down Expand Up @@ -742,7 +760,7 @@ def run_gc
# Note that each thread opens its own database connection
start = clock_time
threads = Array.new(5) do
Thread.new do
new_thread do
new_client do |client|
client.query("SELECT SLEEP(#{sleep_time})")
end
Expand Down Expand Up @@ -773,6 +791,12 @@ def run_gc
result = @client.async_result
expect(result).to be_an_instance_of(Mysql2::Result)
end

it "can close a connection with on the fly async query" do
expect(@client.query("SELECT sleep(0.5)", async: true)).to eql(nil)
@client.close
expect(@client.async_result).to be nil
end
end

context "Multiple results sets" do
Expand Down Expand Up @@ -875,13 +899,13 @@ def run_gc

it "should not overflow the thread stack" do
expect do
Thread.new { Mysql2::Client.escape("'" * 256 * 1024) }.join
new_thread { Mysql2::Client.escape("'" * 256 * 1024) }.join
end.not_to raise_error
end

it "should not overflow the process stack" do
expect do
Thread.new { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
new_thread { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
end.not_to raise_error
end

Expand Down Expand Up @@ -912,13 +936,13 @@ def run_gc

it "should not overflow the thread stack" do
expect do
Thread.new { @client.escape("'" * 256 * 1024) }.join
new_thread { @client.escape("'" * 256 * 1024) }.join
end.not_to raise_error
end

it "should not overflow the process stack" do
expect do
Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join
new_thread { @client.escape("'" * 1024 * 1024 * 4) }.join
end.not_to raise_error
end

Expand Down Expand Up @@ -1011,7 +1035,13 @@ def run_gc

it "should raise a Mysql2::Error::ConnectionError exception upon connection failure due to invalid credentials" do
expect do
new_client(host: 'localhost', username: 'asdfasdf8d2h', password: 'asdfasdfw42')
begin
new_client(host: 'localhost', username: 'asdfasdf8d2h', password: 'asdfasdfw42')
rescue Mysql2::Error => e
raise unless e.message.include?('mysql_native_password')

skip("Native password is not supported")
end
end.to raise_error(Mysql2::Error::ConnectionError)

expect do
Expand Down Expand Up @@ -1273,4 +1303,30 @@ def connect(*args); end
expect(client.inspect).not_to include("pass")
expect(client.inspect).not_to include("secretsecret")
end

it "should not allow concurrent use of #ping" do
@client.ping
thread = new_thread { @client.query("SELECT SLEEP(1)") }
thread.join(0.1)
10.times do
expect do
@client.ping
end.to raise_error(Mysql2::Error, /This connection is in use by/)
end
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
expect(@client.ping).to eq(true)
end

it "should not allow concurrent use of #close" do
@client.ping
thread = new_thread { @client.query("SELECT SLEEP(1)") }
thread.join(0.1)
10.times do
expect do
@client.close
end.to raise_error(Mysql2::Error, /This connection is in use by/)
end
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
expect(@client.close).to be_nil
end
end
10 changes: 10 additions & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def new_client(option_overrides = {})
end
end

def new_thread(&block)
@threads << (thr = Thread.new(&block))
thr
end

def num_classes
# rubocop:disable Lint/UnifiedInteger
0.instance_of?(Integer) ? [Integer] : [Fixnum, Bignum]
Expand Down Expand Up @@ -174,10 +179,15 @@ def ssl_cert_host
end

config.before(:example) do
@threads = []
@client = new_client
end

config.after(:example) do
@threads.each do |thr|
thr.kill
thr.join
end
@clients.each(&:close)
end
end
Loading