Skip to content

Commit 519e7cc

Browse files
committed
Prevent concurrent uses of #ping and #close
Fix: brianmario#1433 Apply the same locking mecanism `#query` uses to these two methods.
1 parent d694a45 commit 519e7cc

3 files changed

Lines changed: 101 additions & 30 deletions

File tree

ext/mysql2/client.c

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,26 @@ static VALUE allocate(VALUE klass) {
420420
return obj;
421421
}
422422

423+
static void rb_mysql_client_set_active_fiber(VALUE self, bool closing) {
424+
VALUE fiber_current = rb_fiber_current();
425+
GET_CLIENT(self);
426+
427+
// see if this connection is still waiting on a result from a previous query
428+
if (NIL_P(wrapper->active_fiber) || (closing && !rb_fiber_alive_p(wrapper->active_fiber))) {
429+
// mark this connection active
430+
wrapper->active_fiber = fiber_current;
431+
} else if (wrapper->active_fiber == fiber_current) {
432+
if (!closing) {
433+
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
434+
}
435+
} else {
436+
VALUE inspect = rb_inspect(wrapper->active_fiber);
437+
const char *thr = StringValueCStr(inspect);
438+
439+
rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
440+
}
441+
}
442+
423443
/* call-seq:
424444
* Mysql2::Client.escape(string)
425445
*
@@ -571,11 +591,14 @@ static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VA
571591
*/
572592
static VALUE rb_mysql_client_close(VALUE self) {
573593
GET_CLIENT(self);
594+
rb_mysql_client_set_active_fiber(self, true);
574595

575596
if (wrapper->client) {
576597
rb_thread_call_without_gvl(nogvl_close, wrapper, RUBY_UBF_IO, 0);
577598
}
578599

600+
wrapper->active_fiber = Qnil;
601+
579602
return Qnil;
580603
}
581604

@@ -798,24 +821,6 @@ static VALUE disconnect_and_mark_inactive(VALUE self) {
798821
return Qnil;
799822
}
800823

801-
static void rb_mysql_client_set_active_fiber(VALUE self) {
802-
VALUE fiber_current = rb_fiber_current();
803-
GET_CLIENT(self);
804-
805-
// see if this connection is still waiting on a result from a previous query
806-
if (NIL_P(wrapper->active_fiber)) {
807-
// mark this connection active
808-
wrapper->active_fiber = fiber_current;
809-
} else if (wrapper->active_fiber == fiber_current) {
810-
rb_raise(cMysql2Error, "This connection is still waiting for a result, try again once you have the result");
811-
} else {
812-
VALUE inspect = rb_inspect(wrapper->active_fiber);
813-
const char *thr = StringValueCStr(inspect);
814-
815-
rb_raise(cMysql2Error, "This connection is in use by: %s", thr);
816-
}
817-
}
818-
819824
/* call-seq:
820825
* client.abandon_results!
821826
*
@@ -873,7 +878,7 @@ static VALUE rb_mysql_query(VALUE self, VALUE sql, VALUE current) {
873878
args.sql_len = RSTRING_LEN(args.sql);
874879
args.wrapper = wrapper;
875880

876-
rb_mysql_client_set_active_fiber(self);
881+
rb_mysql_client_set_active_fiber(self, false);
877882

878883
#ifndef _WIN32
879884
rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0);
@@ -1233,12 +1238,16 @@ static void *nogvl_ping(void *ptr) {
12331238
*/
12341239
static VALUE rb_mysql_client_ping(VALUE self) {
12351240
GET_CLIENT(self);
1241+
rb_mysql_client_set_active_fiber(self, false);
12361242

1243+
VALUE result = Qnil;
12371244
if (!CONNECTED(wrapper)) {
1238-
return Qfalse;
1245+
result = Qfalse;
12391246
} else {
1240-
return (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
1247+
result = (VALUE)rb_thread_call_without_gvl(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
12411248
}
1249+
wrapper->active_fiber = Qnil;
1250+
return result;
12421251
}
12431252

12441253
/* call-seq:

spec/mysql2/client_spec.rb

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def run_gc
608608

609609
it "should detect closed connection on query read error" do
610610
connection_id = @client.thread_id
611-
Thread.new do
611+
new_thread do
612612
sleep(0.1)
613613
Mysql2::Client.new(DatabaseCredentials['root']).tap do |supervisor|
614614
supervisor.query("KILL #{connection_id}")
@@ -633,14 +633,32 @@ def run_gc
633633
end.to raise_error(Mysql2::Error)
634634
end
635635

636+
it "should prevent using a connection held by a dead thread, but not closing it" do
637+
thr = new_thread do
638+
@client.query("SELECT SLEEP(2)")
639+
end
640+
thr.join(0.5)
641+
thr.kill
642+
thr.join
643+
644+
expect { @client.query("SELECT 4") }.to raise_error(Mysql2::Error)
645+
@client.close
646+
end
647+
636648
it "should describe the thread holding the active query" do
637-
thr = Thread.new do
649+
out_queue = Queue.new
650+
in_queue = Queue.new
651+
652+
thr = new_thread do
638653
@client.query("SELECT 1", async: true)
639-
Fiber.current
654+
out_queue << Fiber.current
655+
in_queue.pop
640656
end
641657

642-
fiber = thr.value
658+
fiber = out_queue.pop
643659
expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, Regexp.new(Regexp.escape(fiber.inspect)))
660+
in_queue.close
661+
thr.join
644662
end
645663

646664
it "should timeout if we wait longer than :read_timeout" do
@@ -742,7 +760,7 @@ def run_gc
742760
# Note that each thread opens its own database connection
743761
start = clock_time
744762
threads = Array.new(5) do
745-
Thread.new do
763+
new_thread do
746764
new_client do |client|
747765
client.query("SELECT SLEEP(#{sleep_time})")
748766
end
@@ -773,6 +791,12 @@ def run_gc
773791
result = @client.async_result
774792
expect(result).to be_an_instance_of(Mysql2::Result)
775793
end
794+
795+
it "can close a connection with on the fly async query" do
796+
expect(@client.query("SELECT sleep(0.5)", async: true)).to eql(nil)
797+
@client.close
798+
expect(@client.async_result).to be nil
799+
end
776800
end
777801

778802
context "Multiple results sets" do
@@ -875,13 +899,13 @@ def run_gc
875899

876900
it "should not overflow the thread stack" do
877901
expect do
878-
Thread.new { Mysql2::Client.escape("'" * 256 * 1024) }.join
902+
new_thread { Mysql2::Client.escape("'" * 256 * 1024) }.join
879903
end.not_to raise_error
880904
end
881905

882906
it "should not overflow the process stack" do
883907
expect do
884-
Thread.new { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
908+
new_thread { Mysql2::Client.escape("'" * 1024 * 1024 * 4) }.join
885909
end.not_to raise_error
886910
end
887911

@@ -912,13 +936,13 @@ def run_gc
912936

913937
it "should not overflow the thread stack" do
914938
expect do
915-
Thread.new { @client.escape("'" * 256 * 1024) }.join
939+
new_thread { @client.escape("'" * 256 * 1024) }.join
916940
end.not_to raise_error
917941
end
918942

919943
it "should not overflow the process stack" do
920944
expect do
921-
Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join
945+
new_thread { @client.escape("'" * 1024 * 1024 * 4) }.join
922946
end.not_to raise_error
923947
end
924948

@@ -1279,4 +1303,30 @@ def connect(*args); end
12791303
expect(client.inspect).not_to include("pass")
12801304
expect(client.inspect).not_to include("secretsecret")
12811305
end
1306+
1307+
it "should not allow concurrent use of #ping" do
1308+
@client.ping
1309+
thread = new_thread { @client.query("SELECT SLEEP(1)") }
1310+
thread.join(0.1)
1311+
10.times do
1312+
expect do
1313+
@client.ping
1314+
end.to raise_error(Mysql2::Error, /This connection is in use by/)
1315+
end
1316+
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
1317+
expect(@client.ping).to eq(true)
1318+
end
1319+
1320+
it "should not allow concurrent use of #close" do
1321+
@client.ping
1322+
thread = new_thread { @client.query("SELECT SLEEP(1)") }
1323+
thread.join(0.1)
1324+
10.times do
1325+
expect do
1326+
@client.close
1327+
end.to raise_error(Mysql2::Error, /This connection is in use by/)
1328+
end
1329+
expect(thread.value.to_a).to eq([{ "SLEEP(1)" => 0 }])
1330+
expect(@client.close).to be_nil
1331+
end
12821332
end

spec/spec_helper.rb

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def new_client(option_overrides = {})
5151
end
5252
end
5353

54+
def new_thread(&block)
55+
@threads ||= []
56+
@threads << (thr = Thread.new(&block))
57+
thr
58+
end
59+
5460
def num_classes
5561
# rubocop:disable Lint/UnifiedInteger
5662
0.instance_of?(Integer) ? [Integer] : [Fixnum, Bignum]
@@ -178,6 +184,12 @@ def ssl_cert_host
178184
end
179185

180186
config.after(:example) do
187+
if @threads
188+
@threads.each do |thr|
189+
thr.kill
190+
thr.join
191+
end
192+
end
181193
@clients.each(&:close)
182194
end
183195
end

0 commit comments

Comments
 (0)