Skip to content

Commit 1d843b9

Browse files
Drvinickrobinson251quinnj
authored
Parametrize Connections to avoid type instabilities (#983)
* Parametrize Connections to avoid type instabilities * Fixes and PR feedback * Add a POOLS cache dict * Julia 1.6 compat * Update src/ConnectionPool.jl Co-authored-by: Nick Robinson <[email protected]> * Update src/ConnectionPool.jl Co-authored-by: Nick Robinson <[email protected]> Co-authored-by: Jacob Quinn <[email protected]>
1 parent 5057ad1 commit 1d843b9

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

src/ConnectionPool.jl

+35-22
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Methods are provided for `eof`, `readavailable`,
1010
This allows the `Connection` object to act as a proxy for the
1111
`TCPSocket` or `SSLContext` that it wraps.
1212
13-
The [`POOL`](@ref) is used to manage connection pooling. Connections
13+
[`POOLS`](@ref) are used to manage connection pooling. Connections
1414
are identified by their host, port, whether they require
1515
ssl verification, and whether they are a client or server connection.
1616
If a subsequent request matches these properties of a previous connection
@@ -39,7 +39,7 @@ using .ConnectionPools
3939
"""
4040
Connection
4141
42-
A `TCPSocket` or `SSLContext` connection to a HTTP `host` and `port`.
42+
A `Sockets.TCPSocket`, `MbedTLS.SSLContext` or `OpenSSL.SSLStream` connection to a HTTP `host` and `port`.
4343
4444
Fields:
4545
- `host::String`
@@ -49,7 +49,7 @@ Fields:
4949
- `peerip`, remote IP adress (used for debug/log messages).
5050
- `peerport`, remote TCP port number (used for debug/log messages).
5151
- `localport`, local TCP port number (used for debug messages).
52-
- `io::T`, the `TCPSocket` or `SSLContext.
52+
- `io::T`, the `Sockets.TCPSocket`, `MbedTLS.SSLContext` or `OpenSSL.SSLStream`.
5353
- `clientconnection::Bool`, whether the Connection was created from client code (as opposed to server code)
5454
- `buffer::IOBuffer`, left over bytes read from the connection after
5555
the end of a response header (or chunksize). These bytes are usually
@@ -58,15 +58,15 @@ Fields:
5858
- `readable`, whether the Connection object is readable
5959
- `writable`, whether the Connection object is writable
6060
"""
61-
mutable struct Connection <: IO
61+
mutable struct Connection{IO_t <: IO} <: IO
6262
host::String
6363
port::String
6464
idle_timeout::Int
6565
require_ssl_verification::Bool
6666
peerip::IPAddr # for debugging/logging
6767
peerport::UInt16 # for debugging/logging
6868
localport::UInt16 # debug only
69-
io::IO
69+
io::IO_t
7070
clientconnection::Bool
7171
buffer::IOBuffer
7272
timestamp::Float64
@@ -89,8 +89,8 @@ connectionkey(x::Connection) = (typeof(x.io), x.host, x.port, x.require_ssl_veri
8989

9090
Connection(host::AbstractString, port::AbstractString,
9191
idle_timeout::Int,
92-
require_ssl_verification::Bool, io::IO, client=true) =
93-
Connection(host, port, idle_timeout,
92+
require_ssl_verification::Bool, io::T, client=true) where {T}=
93+
Connection{T}(host, port, idle_timeout,
9494
require_ssl_verification,
9595
safe_getpeername(io)..., localport(io),
9696
io, client, PipeBuffer(), time(), false, false, IOBuffer(), nothing)
@@ -325,22 +325,35 @@ function purge(c::Connection)
325325
@ensure bytesavailable(c) == 0
326326
end
327327

328+
const TCP_POOL = Pool(Connection{Sockets.TCPSocket})
329+
const MbedTLS_SSL_POOL = Pool(Connection{MbedTLS.SSLContext})
330+
const OpenSSL_SSL_POOL = Pool(Connection{OpenSSL.SSLStream})
328331
"""
329-
closeall()
332+
POOLS
330333
331-
Close all connections in`pool`.
334+
A dict of global connection pools keeping track of active connections, split by their IO type.
332335
"""
333-
function closeall()
334-
ConnectionPools.reset!(POOL)
335-
return
336+
const POOLS = Dict{DataType,Pool}(
337+
Sockets.TCPSocket => TCP_POOL,
338+
MbedTLS.SSLContext => MbedTLS_SSL_POOL,
339+
OpenSSL.SSLStream => OpenSSL_SSL_POOL,
340+
)
341+
getpool(::Type{Sockets.TCPSocket}) = TCP_POOL
342+
getpool(::Type{MbedTLS.SSLContext}) = MbedTLS_SSL_POOL
343+
getpool(::Type{OpenSSL.SSLStream}) = OpenSSL_SSL_POOL
344+
# Fallback for custom connection io types
345+
# to opt out from locking, define your own `Pool` and add a `getpool` method for your IO type
346+
const POOLS_LOCK = Threads.ReentrantLock()
347+
function getpool(::Type{T}) where {T}
348+
Base.@lock POOLS_LOCK get!(() -> Pool(Connection{T}), POOLS, T)::Pool{Connection{T}}
336349
end
337350

338351
"""
339-
POOL
352+
closeall()
340353
341-
Global connection pool keeping track of active connections.
354+
Close all connections in `POOLS`.
342355
"""
343-
const POOL = Pool(Connection)
356+
closeall() = foreach(ConnectionPools.reset!, values(POOLS))
344357

345358
"""
346359
newconnection(type, host, port) -> Connection
@@ -355,9 +368,9 @@ function newconnection(::Type{T},
355368
forcenew::Bool=false,
356369
idle_timeout=typemax(Int),
357370
require_ssl_verification::Bool=NetworkOptions.verify_host(host, "SSL"),
358-
kw...)::Connection where {T <: IO}
371+
kw...) where {T <: IO}
359372
return acquire(
360-
POOL,
373+
getpool(T),
361374
(T, host, port, require_ssl_verification, true);
362375
max_concurrent_connections=Int(connection_limit),
363376
forcenew=forcenew,
@@ -370,8 +383,8 @@ function newconnection(::Type{T},
370383
end
371384
end
372385

373-
releaseconnection(c::Connection, reuse) =
374-
release(POOL, connectionkey(c), c; return_for_reuse=reuse)
386+
releaseconnection(c::Connection{T}, reuse) where {T} =
387+
release(getpool(T), connectionkey(c), c; return_for_reuse=reuse)
375388

376389
function keepalive!(tcp)
377390
@debugv 2 "setting keepalive on tcp socket"
@@ -524,7 +537,7 @@ function sslupgrade(::Type{IOType}, c::Connection,
524537
host::AbstractString;
525538
require_ssl_verification::Bool=NetworkOptions.verify_host(host, "SSL"),
526539
readtimeout::Int=0,
527-
kw...)::Connection where {IOType}
540+
kw...)::Connection{IOType} where {IOType}
528541
# initiate the upgrade to SSL
529542
# if the upgrade fails, an error will be thrown and the original c will be closed
530543
# in ConnectionRequest
@@ -538,9 +551,9 @@ function sslupgrade(::Type{IOType}, c::Connection,
538551
# success, now we turn it into a new Connection
539552
conn = Connection(host, "", 0, require_ssl_verification, tls)
540553
# release the "old" one, but don't allow reuse since we're hijacking the socket
541-
release(POOL, connectionkey(c), c; return_for_reuse=false)
554+
release(getpool(IOType), connectionkey(c), c; return_for_reuse=false)
542555
# and return the new one
543-
return acquire(POOL, connectionkey(conn), conn)
556+
return acquire(getpool(IOType), connectionkey(conn), conn)
544557
end
545558

546559
function Base.show(io::IO, c::Connection)

0 commit comments

Comments
 (0)