Skip to content

Commit c13899f

Browse files
committed
Prevent concurrent uses of #ping and #close
Fix: brianmario#1433 Apply the same locking mecanism `#query` uses to these two methods.
1 parent d694a45 commit c13899f

3 files changed

Lines changed: 89 additions & 30 deletions

File tree

ext/mysql2/client.c

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,26 @@ static VALUE allocate(VALUE klass) {
420420
return obj;
421421
}
422422

423+
static void rb_mysql_client_set_active_fiber(VALUE self, bool async_check) {
424+
VALUE fiber_current = rb_fiber_current();
425+
GET_CLIENT(self);
426+
427+
// see if this connection is still waiting on a result from a previous query
428+
if (NIL_P(wrapper->active_fiber) || !rb_fiber_alive_p(wrapper->active_fiber)) {
429+
// mark this connection active
430+
wrapper->active_fiber = fiber_current;
431+
} else if (wrapper->active_fiber == fiber_current) {
432+
if (async_check) {
433+
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
434+
}
435+
} else {
436+
VALUE inspect = rb_inspect(wrapper->active_fiber);
437+
const char *thr = StringValueCStr(inspect);
438+
439+
rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
440+
}
441+
}
442+
423443
/* call-seq:
424444
* Mysql2::Client.escape(string)
425445
*
@@ -571,11 +591,14 @@ static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VA
571591
*/
572592
static VALUE rb_mysql_client_close(VALUE self) {
573593
GET_CLIENT(self);
594+
rb_mysql_client_set_active_fiber(self, false);
574595

575596
if (wrapper->client) {
576597
rb_thread_call_without_gvl(nogvl_close, wrapper, RUBY_UBF_IO, 0);
577598
}
578599

600+
wrapper->active_fiber = Qnil;
601+
579602
return Qnil;
580603
}
581604

@@ -798,24 +821,6 @@ static VALUE disconnect_and_mark_inactive(VALUE self) {
798821
return Qnil;
799822
}
800823

801-
static void rb_mysql_client_set_active_fiber(VALUE self) {
802-
VALUE fiber_current = rb_fiber_current();
803-
GET_CLIENT(self);
804-
805-
// see if this connection is still waiting on a result from a previous query
806-
if (NIL_P(wrapper->active_fiber)) {
807-
// mark this connection active
808-
wrapper->active_fiber = fiber_current;
809-
} else if (wrapper->active_fiber == fiber_current) {
810-
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
811-
} else {
812-
VALUE inspect = rb_inspect(wrapper->active_fiber);
813-
const char *thr = StringValueCStr(inspect);
814-
815-
rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
816-
}
817-
}
818-
819824
/* call-seq:
820825
* client.abandon_results!
821826
*
@@ -873,7 +878,7 @@ static VALUE rb_mysql_query(VALUE self, VALUE sql, VALUE current) {
873878
args.sql_len = RSTRING_LEN(args.sql);
874879
args.wrapper = wrapper;
875880

876-
rb_mysql_client_set_active_fiber(self);
881+
rb_mysql_client_set_active_fiber(self, true);
877882

878883
#ifndef _WIN32
879884
rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0);
@@ -1233,12 +1238,16 @@ static void *nogvl_ping(void *ptr) {
12331238
*/
12341239
static VALUE rb_mysql_client_ping(VALUE self) {
12351240
GET_CLIENT(self);
1241+
rb_mysql_client_set_active_fiber(self, true);
12361242

1243+
VALUE result = Qnil;
12371244
if (!CONNECTED(wrapper)) {
1238-
return Qfalse;
1245+
result = Qfalse;
12391246
} else {
1240-
return (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
1247+
result = (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
12411248
}
1249+
wrapper->active_fiber = Qnil;
1250+
return result;
12421251
}
12431252

12441253
/* call-seq:

spec/mysql2/client_spec.rb

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def run_gc
608608

609609
it "should detect closed connection on query read error" do
610610
connection_id = @client.thread_id
611-
Thread.new do
611+
new_thread do
612612
sleep(0.1)
613613
Mysql2::Client.new(DatabaseCredentials['root']).tap do |supervisor|
614614
supervisor.query("KILL #{connection_id}")
@@ -634,13 +634,19 @@ def run_gc
634634
end
635635

636636
it "should describe the thread holding the active query" do
637-
thr = Thread.new do
637+
out_queue = Queue.new
638+
in_queue = Queue.new
639+
640+
thr = new_thread do
638641
@client.query("SELECT 1", async: true)
639-
Fiber.current
642+
out_queue << Fiber.current
643+
in_queue.pop
640644
end
641645

642-
fiber = thr.value
646+
fiber = out_queue.pop
643647
expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, Regexp.new(Regexp.escape(fiber.inspect)))
648+
in_queue.close
649+
thr.join
644650
end
645651

646652
it "should timeout if we wait longer than :read_timeout" do
@@ -742,7 +748,7 @@ def run_gc
742748
# Note that each thread opens its own database connection
743749
start = clock_time
744750
threads = Array.new(5) do
745-
Thread.new do
751+
new_thread do
746752
new_client do |client|
747753
client.query("SELECT SLEEP(#{sleep_time})")
748754
end
@@ -773,6 +779,12 @@ def run_gc
773779
result = @client.async_result
774780
expect(result).to be_an_instance_of(Mysql2::Result)
775781
end
782+
783+
it "can close a connection with on the fly async query" do
784+
expect(@client.query("SELECT sleep(0.5)", async: true)).to eql(nil)
785+
@client.close
786+
expect(@client.async_result).to be nil
787+
end
776788
end
777789

778790
context "Multiple results sets" do
@@ -875,13 +887,13 @@ def run_gc
875887

876888
it "should not overflow the thread stack" do
877889
expect do
878-
Thread.new { Mysql2::Client.escape("'" * 256 * 1024) }.join
890+
new_thread { Mysql2::Client.escape("'" * 256 * 1024) }.join
879891
end.not_to raise_error
880892
end
881893

882894
it "should not overflow the process stack" do
883895
expect do
884-
Thread.new { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
896+
new_thread { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
885897
end.not_to raise_error
886898
end
887899

@@ -912,13 +924,13 @@ def run_gc
912924

913925
it "should not overflow the thread stack" do
914926
expect do
915-
Thread.new { @client.escape("'" * 256 * 1024) }.join
927+
new_thread { @client.escape("'" * 256 * 1024) }.join
916928
end.not_to raise_error
917929
end
918930

919931
it "should not overflow the process stack" do
920932
expect do
921-
Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join
933+
new_thread { @client.escape("'" * 1024 * 1024 * 4) }.join
922934
end.not_to raise_error
923935
end
924936

@@ -1279,4 +1291,30 @@ def connect(*args); end
12791291
expect(client.inspect).not_to include("pass")
12801292
expect(client.inspect).not_to include("secretsecret")
12811293
end
1294+
1295+
it "should not allow concurrent use of #ping" do
1296+
@client.ping
1297+
thread = new_thread { @client.query("SELECT SLEEP(1)") }
1298+
thread.join(0.1)
1299+
10.times do
1300+
expect do
1301+
@client.ping
1302+
end.to raise_error(Mysql2::Error, /This connection is in use by/)
1303+
end
1304+
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
1305+
expect(@client.ping).to eq(true)
1306+
end
1307+
1308+
it "should not allow concurrent use of #close" do
1309+
@client.ping
1310+
thread = new_thread { @client.query("SELECT SLEEP(1)") }
1311+
thread.join(0.1)
1312+
10.times do
1313+
expect do
1314+
@client.close
1315+
end.to raise_error(Mysql2::Error, /This connection is in use by/)
1316+
end
1317+
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
1318+
expect(@client.close).to be_nil
1319+
end
12821320
end

spec/spec_helper.rb

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def new_client(option_overrides = {})
5151
end
5252
end
5353

54+
def new_thread(&block)
55+
@threads ||= []
56+
@threads << (thr = Thread.new(&block))
57+
thr
58+
end
59+
5460
def num_classes
5561
# rubocop:disable Lint/UnifiedInteger
5662
0.instance_of?(Integer) ? [Integer] : [Fixnum, Bignum]
@@ -178,6 +184,12 @@ def ssl_cert_host
178184
end
179185

180186
config.after(:example) do
187+
if @threads
188+
@threads.each do |thr|
189+
thr.kill
190+
thr.join
191+
end
192+
end
181193
@clients.each(&:close)
182194
end
183195
end

0 commit comments

Comments
 (0)