Skip to content

Commit 94fd83e

Browse files
authored
PYTHON-3814 add types to pyopenssl_context.py (#1341)
1 parent dc63c5d commit 94fd83e

File tree

2 files changed

+67
-38
lines changed

2 files changed

+67
-38
lines changed

pymongo/pyopenssl_context.py

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import time as _time
2323
from errno import EINTR as _EINTR
2424
from ipaddress import ip_address as _ip_address
25+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union
2526

2627
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
2728
from OpenSSL import SSL as _SSL
@@ -39,6 +40,14 @@
3940
from pymongo.socket_checker import _errno_from_exception
4041
from pymongo.write_concern import validate_boolean
4142

43+
if TYPE_CHECKING:
44+
import socket
45+
from ssl import VerifyMode
46+
47+
from cryptography.x509 import Certificate
48+
49+
_T = TypeVar("_T")
50+
4251
try:
4352
import certifi
4453

@@ -73,7 +82,7 @@
7382

7483
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
7584
# not permitted for SNI hostname.
76-
def _is_ip_address(address):
85+
def _is_ip_address(address: Any) -> bool:
7786
try:
7887
_ip_address(address)
7988
return True
@@ -86,7 +95,7 @@ def _is_ip_address(address):
8695
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
8796

8897

89-
def _ragged_eof(exc):
98+
def _ragged_eof(exc: BaseException) -> bool:
9099
"""Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
91100
return exc.args == (-1, "Unexpected EOF")
92101

@@ -95,12 +104,14 @@ def _ragged_eof(exc):
95104
# https://github.com/pyca/pyopenssl/issues/176
96105
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
97106
class _sslConn(_SSL.Connection):
98-
def __init__(self, ctx, sock, suppress_ragged_eofs):
107+
def __init__(
108+
self, ctx: _SSL.Context, sock: Optional[socket.socket], suppress_ragged_eofs: bool
109+
):
99110
self.socket_checker = _SocketChecker()
100111
self.suppress_ragged_eofs = suppress_ragged_eofs
101112
super().__init__(ctx, sock)
102113

103-
def _call(self, call, *args, **kwargs):
114+
def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
104115
timeout = self.gettimeout()
105116
if timeout:
106117
start = _time.monotonic()
@@ -127,10 +138,10 @@ def _call(self, call, *args, **kwargs):
127138
raise _socket.timeout("timed out")
128139
continue
129140

130-
def do_handshake(self, *args, **kwargs):
141+
def do_handshake(self, *args: Any, **kwargs: Any) -> None:
131142
return self._call(super().do_handshake, *args, **kwargs)
132143

133-
def recv(self, *args, **kwargs):
144+
def recv(self, *args: Any, **kwargs: Any) -> bytes:
134145
try:
135146
return self._call(super().recv, *args, **kwargs)
136147
except _SSL.SysCallError as exc:
@@ -139,7 +150,7 @@ def recv(self, *args, **kwargs):
139150
return b""
140151
raise
141152

142-
def recv_into(self, *args, **kwargs):
153+
def recv_into(self, *args: Any, **kwargs: Any) -> int:
143154
try:
144155
return self._call(super().recv_into, *args, **kwargs)
145156
except _SSL.SysCallError as exc:
@@ -148,7 +159,7 @@ def recv_into(self, *args, **kwargs):
148159
return 0
149160
raise
150161

151-
def sendall(self, buf, flags=0):
162+
def sendall(self, buf: bytes, flags: int = 0) -> None: # type: ignore[override]
152163
view = memoryview(buf)
153164
total_length = len(buf)
154165
total_sent = 0
@@ -172,9 +183,9 @@ def sendall(self, buf, flags=0):
172183
class _CallbackData:
173184
"""Data class which is passed to the OCSP callback."""
174185

175-
def __init__(self):
176-
self.trusted_ca_certs = None
177-
self.check_ocsp_endpoint = None
186+
def __init__(self) -> None:
187+
self.trusted_ca_certs: Optional[List[Certificate]] = None
188+
self.check_ocsp_endpoint: Optional[bool] = None
178189
self.ocsp_response_cache = _OCSPCache()
179190

180191

@@ -185,7 +196,7 @@ class SSLContext:
185196

186197
__slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")
187198

188-
def __init__(self, protocol):
199+
def __init__(self, protocol: int):
189200
self._protocol = protocol
190201
self._ctx = _SSL.Context(self._protocol)
191202
self._callback_data = _CallbackData()
@@ -198,66 +209,80 @@ def __init__(self, protocol):
198209
self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)
199210

200211
@property
201-
def protocol(self):
212+
def protocol(self) -> int:
202213
"""The protocol version chosen when constructing the context.
203214
This attribute is read-only.
204215
"""
205216
return self._protocol
206217

207-
def __get_verify_mode(self):
218+
def __get_verify_mode(self) -> VerifyMode:
208219
"""Whether to try to verify other peers' certificates and how to
209220
behave if verification fails. This attribute must be one of
210221
ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
211222
"""
212223
return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]
213224

214-
def __set_verify_mode(self, value):
225+
def __set_verify_mode(self, value: VerifyMode) -> None:
215226
"""Setter for verify_mode."""
216227

217-
def _cb(connobj, x509obj, errnum, errdepth, retcode):
228+
def _cb(
229+
connobj: _SSL.Connection,
230+
x509obj: _crypto.X509,
231+
errnum: int,
232+
errdepth: int,
233+
retcode: int,
234+
) -> bool:
218235
# It seems we don't need to do anything here. Twisted doesn't,
219236
# and OpenSSL's SSL_CTX_set_verify let's you pass NULL
220237
# for the callback option. It's weird that PyOpenSSL requires
221238
# this.
222-
return retcode
239+
# This is optional in pyopenssl >= 20 and can be removed once minimum
240+
# supported version is bumped
241+
# See: pyopenssl.org/en/latest/changelog.html#id47
242+
return bool(retcode)
223243

224244
self._ctx.set_verify(_VERIFY_MAP[value], _cb)
225245

226246
verify_mode = property(__get_verify_mode, __set_verify_mode)
227247

228-
def __get_check_hostname(self):
248+
def __get_check_hostname(self) -> bool:
229249
return self._check_hostname
230250

231-
def __set_check_hostname(self, value):
251+
def __set_check_hostname(self, value: Any) -> None:
232252
validate_boolean("check_hostname", value)
233253
self._check_hostname = value
234254

235255
check_hostname = property(__get_check_hostname, __set_check_hostname)
236256

237-
def __get_check_ocsp_endpoint(self):
257+
def __get_check_ocsp_endpoint(self) -> Optional[bool]:
238258
return self._callback_data.check_ocsp_endpoint
239259

240-
def __set_check_ocsp_endpoint(self, value):
260+
def __set_check_ocsp_endpoint(self, value: bool) -> None:
241261
validate_boolean("check_ocsp", value)
242262
self._callback_data.check_ocsp_endpoint = value
243263

244264
check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)
245265

246-
def __get_options(self):
266+
def __get_options(self) -> None:
247267
# Calling set_options adds the option to the existing bitmask and
248268
# returns the new bitmask.
249269
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
250270
return self._ctx.set_options(0)
251271

252-
def __set_options(self, value):
272+
def __set_options(self, value: int) -> None:
253273
# Explcitly convert to int, since newer CPython versions
254274
# use enum.IntFlag for options. The values are the same
255275
# regardless of implementation.
256276
self._ctx.set_options(int(value))
257277

258278
options = property(__get_options, __set_options)
259279

260-
def load_cert_chain(self, certfile, keyfile=None, password=None):
280+
def load_cert_chain(
281+
self,
282+
certfile: Union[str, bytes],
283+
keyfile: Union[str, bytes, None] = None,
284+
password: Optional[str] = None,
285+
) -> None:
261286
"""Load a private key and the corresponding certificate. The certfile
262287
string must be the path to a single file in PEM format containing the
263288
certificate as well as any number of CA certificates needed to
@@ -270,28 +295,32 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
270295
# Password callback MUST be set first or it will be ignored.
271296
if password:
272297

273-
def _pwcb(max_length, prompt_twice, user_data):
298+
def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes:
274299
# XXX:We could check the password length against what OpenSSL
275300
# tells us is the max, but we can't raise an exception, so...
276301
# warn?
302+
assert password is not None
277303
return password.encode("utf-8")
278304

279305
self._ctx.set_passwd_cb(_pwcb)
280306
self._ctx.use_certificate_chain_file(certfile)
281307
self._ctx.use_privatekey_file(keyfile or certfile)
282308
self._ctx.check_privatekey()
283309

284-
def load_verify_locations(self, cafile=None, capath=None):
310+
def load_verify_locations(
311+
self, cafile: Optional[str] = None, capath: Optional[str] = None
312+
) -> None:
285313
"""Load a set of "certification authority"(CA) certificates used to
286314
validate other peers' certificates when `~verify_mode` is other than
287315
ssl.CERT_NONE.
288316
"""
289317
self._ctx.load_verify_locations(cafile, capath)
290318
# Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
291319
if not hasattr(_SSL.Connection, "get_verified_chain"):
320+
assert cafile is not None
292321
self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)
293322

294-
def _load_certifi(self):
323+
def _load_certifi(self) -> None:
295324
"""Attempt to load CA certs from certifi."""
296325
if _HAVE_CERTIFI:
297326
self.load_verify_locations(certifi.where())
@@ -303,7 +332,7 @@ def _load_certifi(self):
303332
"the tlsCAFile option"
304333
)
305334

306-
def _load_wincerts(self, store):
335+
def _load_wincerts(self, store: str) -> None:
307336
"""Attempt to load CA certs from Windows trust store."""
308337
cert_store = self._ctx.get_cert_store()
309338
oid = _stdlibssl.Purpose.SERVER_AUTH.oid
@@ -314,7 +343,7 @@ def _load_wincerts(self, store):
314343
_crypto.X509.from_cryptography(_load_der_x509_certificate(cert))
315344
)
316345

317-
def load_default_certs(self):
346+
def load_default_certs(self) -> None:
318347
"""A PyOpenSSL version of load_default_certs from CPython."""
319348
# PyOpenSSL is incapable of loading CA certs from Windows, and mostly
320349
# incapable on macOS.
@@ -330,7 +359,7 @@ def load_default_certs(self):
330359
self._load_certifi()
331360
self._ctx.set_default_verify_paths()
332361

333-
def set_default_verify_paths(self):
362+
def set_default_verify_paths(self) -> None:
334363
"""Specify that the platform provided CA certificates are to be used
335364
for verification purposes.
336365
"""
@@ -340,13 +369,13 @@ def set_default_verify_paths(self):
340369

341370
def wrap_socket(
342371
self,
343-
sock,
344-
server_side=False,
345-
do_handshake_on_connect=True,
346-
suppress_ragged_eofs=True,
347-
server_hostname=None,
348-
session=None,
349-
):
372+
sock: socket.socket,
373+
server_side: bool = False,
374+
do_handshake_on_connect: bool = True,
375+
suppress_ragged_eofs: bool = True,
376+
server_hostname: Optional[str] = None,
377+
session: Optional[_SSL.Session] = None,
378+
) -> _sslConn:
350379
"""Wrap an existing Python socket connection and return a TLS socket
351380
object.
352381
"""

tools/ocsptest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def check_ocsp(host, port, capath):
4242
s = socket.socket()
4343
s.connect((host, port))
4444
try:
45-
s = ctx.wrap_socket(s, server_hostname=host)
45+
s = ctx.wrap_socket(s, server_hostname=host) # type: ignore[assignment]
4646
finally:
4747
s.close()
4848

0 commit comments

Comments
 (0)