Skip to content

Commit cf6d650

Browse files
committed
Refactor code a bit to reduce allocations and dynamic dispatches
Together with JuliaWeb/HTTP.jl#985, this reduces allocations on a typical request by about ~100. The main wins here are hard-coding TCPSocket as the `io` field of `SSLStream` (which we assume anyway), and changing the `geterror` function to a macro to avoid the closure boxing problem.
1 parent 6789aa0 commit cf6d650

File tree

1 file changed

+49
-53
lines changed

1 file changed

+49
-53
lines changed

src/ssl.jl

+49-53
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ mutable struct SSLStream <: IO
397397
ssl_context::SSLContext
398398
rbio::BIO
399399
wbio::BIO
400-
io::IO
400+
io::TCPSocket
401401
lock::ReentrantLock
402402
@static if VERSION < v"1.7"
403403
close_notify_received::Threads.Atomic{Bool}
@@ -407,7 +407,7 @@ else
407407
@atomic closed::Bool
408408
end
409409

410-
function SSLStream(ssl_context::SSLContext, io::IO)
410+
function SSLStream(ssl_context::SSLContext, io::TCPSocket)
411411
# Create a read and write BIOs.
412412
bio_read::BIO = BIO(io; finalize=false)
413413
bio_write::BIO = BIO(io; finalize=false)
@@ -422,7 +422,7 @@ end
422422
end
423423

424424
# backwards compat
425-
SSLStream(ssl_context::SSLContext, io::IO, ::IO) = SSLStream(ssl_context, io)
425+
SSLStream(ssl_context::SSLContext, io::TCPSocket, ::TCPSocket) = SSLStream(ssl_context, io)
426426
Base.getproperty(ssl::SSLStream, nm::Symbol) = nm === :bio_read_stream ? ssl : getfield(ssl, nm)
427427

428428
SSLStream(tcp::TCPSocket) = SSLStream(SSLContext(OpenSSL.TLSClientMethod()), tcp)
@@ -432,45 +432,43 @@ Base.isopen(ssl::SSLStream)::Bool = !@atomicget(ssl.closed)
432432
Base.iswritable(ssl::SSLStream)::Bool = isopen(ssl) && isopen(ssl.io)
433433
check_isopen(ssl::SSLStream, op) = isopen(ssl) || throw(Base.IOError("$op requires ssl to be open", 0))
434434

435-
function geterror(f, ssl::SSLStream)
436-
ret = f()
437-
if ret <= 0
438-
err = get_error(ssl.ssl, ret)
439-
if err == SSL_ERROR_ZERO_RETURN
440-
@atomicset ssl.close_notify_received = true
441-
elseif err == SSL_ERROR_NONE
442-
# pass
443-
elseif err == SSL_ERROR_WANT_READ
444-
return SSL_ERROR_WANT_READ
445-
elseif err == SSL_ERROR_WANT_WRITE
446-
return SSL_ERROR_WANT_WRITE
447-
else
448-
close(ssl, false)
449-
throw(Base.IOError(OpenSSLError(err).msg, 0))
435+
macro geterror(expr)
436+
esc(quote
437+
ret = $expr
438+
if ret <= 0
439+
err = get_error(ssl.ssl, ret)
440+
if err == SSL_ERROR_ZERO_RETURN
441+
@atomicset ssl.close_notify_received = true
442+
elseif err == SSL_ERROR_NONE
443+
# pass
444+
elseif err == SSL_ERROR_WANT_READ
445+
ret = SSL_ERROR_WANT_READ
446+
elseif err == SSL_ERROR_WANT_WRITE
447+
ret = SSL_ERROR_WANT_WRITE
448+
else
449+
close(ssl, false)
450+
throw(Base.IOError(OpenSSLError(err).msg, 0))
451+
end
450452
end
451-
end
452-
return ret
453+
end)
453454
end
454455

455456
function Base.unsafe_write(ssl::SSLStream, in_buffer::Ptr{UInt8}, in_length::UInt)
456457
check_isopen(ssl, "unsafe_write")
457-
return geterror(ssl) do
458-
ccall(
459-
(:SSL_write, libssl),
460-
Cint,
461-
(SSL, Ptr{Cvoid}, Cint),
462-
ssl.ssl,
463-
in_buffer,
464-
in_length)
465-
end
458+
return @geterror ccall(
459+
(:SSL_write, libssl),
460+
Cint,
461+
(SSL, Ptr{Cvoid}, Cint),
462+
ssl.ssl,
463+
in_buffer,
464+
in_length
465+
)
466466
end
467467

468468
function Sockets.connect(ssl::SSLStream; require_ssl_verification::Bool=true)
469469
while true
470470
check_isopen(ssl, "connect")
471-
ret = geterror(ssl) do
472-
ssl_connect(ssl.ssl)
473-
end
471+
@geterror ssl_connect(ssl.ssl)
474472
if (ret == 1 || ret == SSL_ERROR_NONE)
475473
break
476474
elseif ret == SSL_ERROR_WANT_READ
@@ -537,17 +535,16 @@ function Base.unsafe_read(ssl::SSLStream, buf::Ptr{UInt8}, nbytes::UInt)
537535
while nread < nbytes
538536
(!isopen(ssl) || eof(ssl)) && throw(EOFError())
539537
readbytes = Ref{Csize_t}()
540-
geterror(ssl) do
541-
ccall(
542-
(:SSL_read_ex, libssl),
543-
Cint,
544-
(SSL, Ptr{Int8}, Csize_t, Ptr{Csize_t}),
545-
ssl.ssl,
546-
buf + nread,
547-
nbytes - nread,
548-
readbytes)
549-
end
550-
nread += readbytes[]
538+
@geterror ccall(
539+
(:SSL_read_ex, libssl),
540+
Cint,
541+
(SSL, Ptr{Int8}, Csize_t, Ptr{Csize_t}),
542+
ssl.ssl,
543+
buf + nread,
544+
nbytes - nread,
545+
readbytes
546+
)
547+
nread += Int(readbytes[])
551548
end
552549
return nread
553550
end
@@ -600,19 +597,18 @@ function Base.eof(ssl::SSLStream)::Bool
600597
end
601598
# if we're here, we know there are unprocessed bytes,
602599
# so we call peek to force processing
603-
err = geterror(ssl) do
604-
ccall(
605-
(:SSL_peek, libssl),
606-
Cint,
607-
(SSL, Ref{UInt8}, Cint),
608-
ssl.ssl,
609-
PEEK_REF,
610-
1)
611-
end
600+
@geterror ccall(
601+
(:SSL_peek, libssl),
602+
Cint,
603+
(SSL, Ref{UInt8}, Cint),
604+
ssl.ssl,
605+
PEEK_REF,
606+
1
607+
)
612608
# if we get WANT_READ back, that means there were pending bytes
613609
# to be processed, but not a full record, so we need to wait
614610
# for additional bytes to come in before we can process
615-
err == SSL_ERROR_WANT_READ && eof(ssl.io)
611+
ret == SSL_ERROR_WANT_READ && eof(ssl.io)
616612
end
617613
end
618614
bytesavailable(ssl) > 0 && return false

0 commit comments

Comments
 (0)