22
22
import time as _time
23
23
from errno import EINTR as _EINTR
24
24
from ipaddress import ip_address as _ip_address
25
+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , TypeVar , Union
25
26
26
27
from cryptography .x509 import load_der_x509_certificate as _load_der_x509_certificate
27
28
from OpenSSL import SSL as _SSL
39
40
from pymongo .socket_checker import _errno_from_exception
40
41
from pymongo .write_concern import validate_boolean
41
42
43
+ if TYPE_CHECKING :
44
+ import socket
45
+ from ssl import VerifyMode
46
+
47
+ from cryptography .x509 import Certificate
48
+
49
+ _T = TypeVar ("_T" )
50
+
42
51
try :
43
52
import certifi
44
53
73
82
74
83
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
75
84
# not permitted for SNI hostname.
76
- def _is_ip_address (address ) :
85
+ def _is_ip_address (address : Any ) -> bool :
77
86
try :
78
87
_ip_address (address )
79
88
return True
@@ -86,7 +95,7 @@ def _is_ip_address(address):
86
95
BLOCKING_IO_ERRORS = (_SSL .WantReadError , _SSL .WantWriteError , _SSL .WantX509LookupError )
87
96
88
97
89
- def _ragged_eof (exc ) :
98
+ def _ragged_eof (exc : BaseException ) -> bool :
90
99
"""Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
91
100
return exc .args == (- 1 , "Unexpected EOF" )
92
101
@@ -95,12 +104,14 @@ def _ragged_eof(exc):
95
104
# https://github.com/pyca/pyopenssl/issues/176
96
105
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
97
106
class _sslConn (_SSL .Connection ):
98
- def __init__ (self , ctx , sock , suppress_ragged_eofs ):
107
+ def __init__ (
108
+ self , ctx : _SSL .Context , sock : Optional [socket .socket ], suppress_ragged_eofs : bool
109
+ ):
99
110
self .socket_checker = _SocketChecker ()
100
111
self .suppress_ragged_eofs = suppress_ragged_eofs
101
112
super ().__init__ (ctx , sock )
102
113
103
- def _call (self , call , * args , ** kwargs ) :
114
+ def _call (self , call : Callable [..., _T ], * args : Any , ** kwargs : Any ) -> _T :
104
115
timeout = self .gettimeout ()
105
116
if timeout :
106
117
start = _time .monotonic ()
@@ -127,10 +138,10 @@ def _call(self, call, *args, **kwargs):
127
138
raise _socket .timeout ("timed out" )
128
139
continue
129
140
130
- def do_handshake (self , * args , ** kwargs ) :
141
+ def do_handshake (self , * args : Any , ** kwargs : Any ) -> None :
131
142
return self ._call (super ().do_handshake , * args , ** kwargs )
132
143
133
- def recv (self , * args , ** kwargs ) :
144
+ def recv (self , * args : Any , ** kwargs : Any ) -> bytes :
134
145
try :
135
146
return self ._call (super ().recv , * args , ** kwargs )
136
147
except _SSL .SysCallError as exc :
@@ -139,7 +150,7 @@ def recv(self, *args, **kwargs):
139
150
return b""
140
151
raise
141
152
142
- def recv_into (self , * args , ** kwargs ) :
153
+ def recv_into (self , * args : Any , ** kwargs : Any ) -> int :
143
154
try :
144
155
return self ._call (super ().recv_into , * args , ** kwargs )
145
156
except _SSL .SysCallError as exc :
@@ -148,7 +159,7 @@ def recv_into(self, *args, **kwargs):
148
159
return 0
149
160
raise
150
161
151
- def sendall (self , buf , flags = 0 ):
162
+ def sendall (self , buf : bytes , flags : int = 0 ) -> None : # type: ignore[override]
152
163
view = memoryview (buf )
153
164
total_length = len (buf )
154
165
total_sent = 0
@@ -172,9 +183,9 @@ def sendall(self, buf, flags=0):
172
183
class _CallbackData :
173
184
"""Data class which is passed to the OCSP callback."""
174
185
175
- def __init__ (self ):
176
- self .trusted_ca_certs = None
177
- self .check_ocsp_endpoint = None
186
+ def __init__ (self ) -> None :
187
+ self .trusted_ca_certs : Optional [ List [ Certificate ]] = None
188
+ self .check_ocsp_endpoint : Optional [ bool ] = None
178
189
self .ocsp_response_cache = _OCSPCache ()
179
190
180
191
@@ -185,7 +196,7 @@ class SSLContext:
185
196
186
197
__slots__ = ("_protocol" , "_ctx" , "_callback_data" , "_check_hostname" )
187
198
188
- def __init__ (self , protocol ):
199
+ def __init__ (self , protocol : int ):
189
200
self ._protocol = protocol
190
201
self ._ctx = _SSL .Context (self ._protocol )
191
202
self ._callback_data = _CallbackData ()
@@ -198,66 +209,80 @@ def __init__(self, protocol):
198
209
self ._ctx .set_ocsp_client_callback (callback = _ocsp_callback , data = self ._callback_data )
199
210
200
211
@property
201
- def protocol (self ):
212
+ def protocol (self ) -> int :
202
213
"""The protocol version chosen when constructing the context.
203
214
This attribute is read-only.
204
215
"""
205
216
return self ._protocol
206
217
207
- def __get_verify_mode (self ):
218
+ def __get_verify_mode (self ) -> VerifyMode :
208
219
"""Whether to try to verify other peers' certificates and how to
209
220
behave if verification fails. This attribute must be one of
210
221
ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
211
222
"""
212
223
return _REVERSE_VERIFY_MAP [self ._ctx .get_verify_mode ()]
213
224
214
- def __set_verify_mode (self , value ) :
225
+ def __set_verify_mode (self , value : VerifyMode ) -> None :
215
226
"""Setter for verify_mode."""
216
227
217
- def _cb (connobj , x509obj , errnum , errdepth , retcode ):
228
+ def _cb (
229
+ connobj : _SSL .Connection ,
230
+ x509obj : _crypto .X509 ,
231
+ errnum : int ,
232
+ errdepth : int ,
233
+ retcode : int ,
234
+ ) -> bool :
218
235
# It seems we don't need to do anything here. Twisted doesn't,
219
236
# and OpenSSL's SSL_CTX_set_verify let's you pass NULL
220
237
# for the callback option. It's weird that PyOpenSSL requires
221
238
# this.
222
- return retcode
239
+ # This is optional in pyopenssl >= 20 and can be removed once minimum
240
+ # supported version is bumped
241
+ # See: pyopenssl.org/en/latest/changelog.html#id47
242
+ return bool (retcode )
223
243
224
244
self ._ctx .set_verify (_VERIFY_MAP [value ], _cb )
225
245
226
246
verify_mode = property (__get_verify_mode , __set_verify_mode )
227
247
228
- def __get_check_hostname (self ):
248
+ def __get_check_hostname (self ) -> bool :
229
249
return self ._check_hostname
230
250
231
- def __set_check_hostname (self , value ) :
251
+ def __set_check_hostname (self , value : Any ) -> None :
232
252
validate_boolean ("check_hostname" , value )
233
253
self ._check_hostname = value
234
254
235
255
check_hostname = property (__get_check_hostname , __set_check_hostname )
236
256
237
- def __get_check_ocsp_endpoint (self ):
257
+ def __get_check_ocsp_endpoint (self ) -> Optional [ bool ] :
238
258
return self ._callback_data .check_ocsp_endpoint
239
259
240
- def __set_check_ocsp_endpoint (self , value ) :
260
+ def __set_check_ocsp_endpoint (self , value : bool ) -> None :
241
261
validate_boolean ("check_ocsp" , value )
242
262
self ._callback_data .check_ocsp_endpoint = value
243
263
244
264
check_ocsp_endpoint = property (__get_check_ocsp_endpoint , __set_check_ocsp_endpoint )
245
265
246
- def __get_options (self ):
266
+ def __get_options (self ) -> None :
247
267
# Calling set_options adds the option to the existing bitmask and
248
268
# returns the new bitmask.
249
269
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
250
270
return self ._ctx .set_options (0 )
251
271
252
- def __set_options (self , value ) :
272
+ def __set_options (self , value : int ) -> None :
253
273
# Explcitly convert to int, since newer CPython versions
254
274
# use enum.IntFlag for options. The values are the same
255
275
# regardless of implementation.
256
276
self ._ctx .set_options (int (value ))
257
277
258
278
options = property (__get_options , __set_options )
259
279
260
- def load_cert_chain (self , certfile , keyfile = None , password = None ):
280
+ def load_cert_chain (
281
+ self ,
282
+ certfile : Union [str , bytes ],
283
+ keyfile : Union [str , bytes , None ] = None ,
284
+ password : Optional [str ] = None ,
285
+ ) -> None :
261
286
"""Load a private key and the corresponding certificate. The certfile
262
287
string must be the path to a single file in PEM format containing the
263
288
certificate as well as any number of CA certificates needed to
@@ -270,28 +295,32 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
270
295
# Password callback MUST be set first or it will be ignored.
271
296
if password :
272
297
273
- def _pwcb (max_length , prompt_twice , user_data ) :
298
+ def _pwcb (max_length : int , prompt_twice : bool , user_data : bytes ) -> bytes :
274
299
# XXX:We could check the password length against what OpenSSL
275
300
# tells us is the max, but we can't raise an exception, so...
276
301
# warn?
302
+ assert password is not None
277
303
return password .encode ("utf-8" )
278
304
279
305
self ._ctx .set_passwd_cb (_pwcb )
280
306
self ._ctx .use_certificate_chain_file (certfile )
281
307
self ._ctx .use_privatekey_file (keyfile or certfile )
282
308
self ._ctx .check_privatekey ()
283
309
284
- def load_verify_locations (self , cafile = None , capath = None ):
310
+ def load_verify_locations (
311
+ self , cafile : Optional [str ] = None , capath : Optional [str ] = None
312
+ ) -> None :
285
313
"""Load a set of "certification authority"(CA) certificates used to
286
314
validate other peers' certificates when `~verify_mode` is other than
287
315
ssl.CERT_NONE.
288
316
"""
289
317
self ._ctx .load_verify_locations (cafile , capath )
290
318
# Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
291
319
if not hasattr (_SSL .Connection , "get_verified_chain" ):
320
+ assert cafile is not None
292
321
self ._callback_data .trusted_ca_certs = _load_trusted_ca_certs (cafile )
293
322
294
- def _load_certifi (self ):
323
+ def _load_certifi (self ) -> None :
295
324
"""Attempt to load CA certs from certifi."""
296
325
if _HAVE_CERTIFI :
297
326
self .load_verify_locations (certifi .where ())
@@ -303,7 +332,7 @@ def _load_certifi(self):
303
332
"the tlsCAFile option"
304
333
)
305
334
306
- def _load_wincerts (self , store ) :
335
+ def _load_wincerts (self , store : str ) -> None :
307
336
"""Attempt to load CA certs from Windows trust store."""
308
337
cert_store = self ._ctx .get_cert_store ()
309
338
oid = _stdlibssl .Purpose .SERVER_AUTH .oid
@@ -314,7 +343,7 @@ def _load_wincerts(self, store):
314
343
_crypto .X509 .from_cryptography (_load_der_x509_certificate (cert ))
315
344
)
316
345
317
- def load_default_certs (self ):
346
+ def load_default_certs (self ) -> None :
318
347
"""A PyOpenSSL version of load_default_certs from CPython."""
319
348
# PyOpenSSL is incapable of loading CA certs from Windows, and mostly
320
349
# incapable on macOS.
@@ -330,7 +359,7 @@ def load_default_certs(self):
330
359
self ._load_certifi ()
331
360
self ._ctx .set_default_verify_paths ()
332
361
333
- def set_default_verify_paths (self ):
362
+ def set_default_verify_paths (self ) -> None :
334
363
"""Specify that the platform provided CA certificates are to be used
335
364
for verification purposes.
336
365
"""
@@ -340,13 +369,13 @@ def set_default_verify_paths(self):
340
369
341
370
def wrap_socket (
342
371
self ,
343
- sock ,
344
- server_side = False ,
345
- do_handshake_on_connect = True ,
346
- suppress_ragged_eofs = True ,
347
- server_hostname = None ,
348
- session = None ,
349
- ):
372
+ sock : socket . socket ,
373
+ server_side : bool = False ,
374
+ do_handshake_on_connect : bool = True ,
375
+ suppress_ragged_eofs : bool = True ,
376
+ server_hostname : Optional [ str ] = None ,
377
+ session : Optional [ _SSL . Session ] = None ,
378
+ ) -> _sslConn :
350
379
"""Wrap an existing Python socket connection and return a TLS socket
351
380
object.
352
381
"""
0 commit comments