Skip to content

Commit bc7b28a

Browse files
Add support for connection reference in DBConnection.connection_module/1 (#256)
1 parent 0938c4e commit bc7b28a

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

lib/db_connection.ex

+11-5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ defmodule DBConnection do
6868
require Logger
6969

7070
alias DBConnection.Holder
71+
72+
require Holder
73+
7174
defstruct [:pool_ref, :conn_ref, :conn_mode]
7275

7376
defmodule EncodeError do
@@ -1093,19 +1096,22 @@ defmodule DBConnection do
10931096
end
10941097

10951098
@doc """
1096-
Returns connection module used by the given connection pool process.
1099+
Returns connection module used by the given connection pool.
10971100
1098-
If the given process is not a connection pool, `:error` is returned.
1101+
When given a process that is not a connection pool, returns an `:error`.
10991102
"""
1100-
@spec connection_module(pid() | atom()) :: {:ok, module()} | :error
1101-
def connection_module(pool) when is_pid(pool) or is_atom(pool) do
1102-
with pid when pid != nil <- GenServer.whereis(pool),
1103+
@spec connection_module(conn) :: {:ok, module} | :error
1104+
def connection_module(conn) do
1105+
with pid when pid != nil <- pool_pid(conn),
11031106
{:dictionary, dictionary} <- Process.info(pid, :dictionary),
11041107
{:ok, module} <- fetch_from_dictionary(dictionary, @connection_module_key),
11051108
do: {:ok, module},
11061109
else: (_ -> :error)
11071110
end
11081111

1112+
defp pool_pid(%DBConnection{pool_ref: Holder.pool_ref(pool: pid)}), do: pid
1113+
defp pool_pid(conn), do: GenServer.whereis(conn)
1114+
11091115
defp fetch_from_dictionary(dictionary, key) do
11101116
Enum.find_value(dictionary, :error, fn
11111117
{^key, value} -> {:ok, value}

test/db_connection_test.exs

+11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ defmodule DBConnectionTest do
4343
assert {:ok, TestConnection} = DBConnection.connection_module(name)
4444
end
4545

46+
test "returns the connection module when given a locked connection reference" do
47+
{:ok, agent} = A.start_link([{:ok, :state}, {:idle, :state}, {:idle, :state}])
48+
49+
opts = [agent: agent]
50+
{:ok, pool} = P.start_link(opts)
51+
52+
P.run(pool, fn conn ->
53+
assert {:ok, TestConnection} = DBConnection.connection_module(conn)
54+
end)
55+
end
56+
4657
test "returns an error if the given process is not a pool" do
4758
assert :error = DBConnection.connection_module(self())
4859
end

0 commit comments

Comments
 (0)