36
36
#include "py/runtime.h"
37
37
#include "py/stream.h"
38
38
#include "py/objstr.h"
39
+ #include "py/reader.h"
40
+ #include "extmod/vfs.h"
39
41
40
42
// mbedtls_time_t
41
43
#include "mbedtls/platform.h"
46
48
#include "mbedtls/ctr_drbg.h"
47
49
#include "mbedtls/debug.h"
48
50
#include "mbedtls/error.h"
51
+ #if MBEDTLS_VERSION_NUMBER >= 0x03000000
52
+ #include "mbedtls/build_info.h"
53
+ #else
54
+ #include "mbedtls/version.h"
55
+ #endif
49
56
50
57
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
51
58
@@ -59,6 +66,7 @@ typedef struct _mp_obj_ssl_context_t {
59
66
mbedtls_x509_crt cert ;
60
67
mbedtls_pk_context pkey ;
61
68
int authmode ;
69
+ int * ciphersuites ;
62
70
} mp_obj_ssl_context_t ;
63
71
64
72
// This corresponds to an SSLSocket object.
@@ -75,12 +83,32 @@ typedef struct _mp_obj_ssl_socket_t {
75
83
STATIC const mp_obj_type_t ssl_context_type ;
76
84
STATIC const mp_obj_type_t ssl_socket_type ;
77
85
86
+ STATIC const MP_DEFINE_STR_OBJ (mbedtls_version_obj , MBEDTLS_VERSION_STRING_FULL );
87
+
78
88
STATIC mp_obj_t ssl_socket_make_new (mp_obj_ssl_context_t * ssl_context , mp_obj_t sock ,
79
89
bool server_side , bool do_handshake_on_connect , mp_obj_t server_hostname );
80
90
81
91
/******************************************************************************/
82
92
// Helper functions.
83
93
94
+ STATIC mp_obj_t read_file (mp_obj_t self_in ) {
95
+ // file = open(args[0], "rb")
96
+ mp_obj_t f_args [2 ] = {
97
+ self_in ,
98
+ MP_OBJ_NEW_QSTR (MP_QSTR_rb ),
99
+ };
100
+ mp_obj_t file = mp_vfs_open (2 , & f_args [0 ], (mp_map_t * )& mp_const_empty_map );
101
+
102
+ // data = file.read()
103
+ mp_obj_t dest [2 ];
104
+ mp_load_method (file , MP_QSTR_read , dest );
105
+ mp_obj_t data = mp_call_method_n_kw (0 , 0 , dest );
106
+
107
+ // file.close()
108
+ mp_stream_close (file );
109
+ return data ;
110
+ }
111
+
84
112
#ifdef MBEDTLS_DEBUG_C
85
113
STATIC void mbedtls_debug (void * ctx , int level , const char * file , int line , const char * str ) {
86
114
(void )ctx ;
@@ -162,6 +190,7 @@ STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
162
190
mbedtls_x509_crt_init (& self -> cacert );
163
191
mbedtls_x509_crt_init (& self -> cert );
164
192
mbedtls_pk_init (& self -> pkey );
193
+ self -> ciphersuites = NULL ;
165
194
166
195
#ifdef MBEDTLS_DEBUG_C
167
196
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
@@ -236,6 +265,54 @@ STATIC mp_obj_t ssl_context___del__(mp_obj_t self_in) {
236
265
STATIC MP_DEFINE_CONST_FUN_OBJ_1 (ssl_context___del___obj , ssl_context___del__ );
237
266
#endif
238
267
268
+ // SSLContext.get_ciphers()
269
+ STATIC mp_obj_t ssl_context_get_ciphers (mp_obj_t self_in ) {
270
+ mp_obj_t list = mp_obj_new_list (0 , NULL );
271
+ for (const int * cipher_list = mbedtls_ssl_list_ciphersuites (); * cipher_list ; ++ cipher_list ) {
272
+ const char * cipher_name = mbedtls_ssl_get_ciphersuite_name (* cipher_list );
273
+ mp_obj_list_append (list , MP_OBJ_FROM_PTR (mp_obj_new_str (cipher_name , strlen (cipher_name ))));
274
+ cipher_list ++ ;
275
+ if (!* cipher_list ) {
276
+ break ;
277
+ }
278
+ }
279
+ return list ;
280
+ }
281
+ STATIC MP_DEFINE_CONST_FUN_OBJ_1 (ssl_context_get_ciphers_obj , ssl_context_get_ciphers );
282
+
283
+ // SSLContext.set_ciphers(ciphersuite)
284
+ STATIC mp_obj_t ssl_context_set_ciphers (mp_obj_t self_in , mp_obj_t ciphersuite ) {
285
+ mp_obj_ssl_context_t * ssl_context = MP_OBJ_TO_PTR (self_in );
286
+
287
+ // Check that ciphersuite is a list or tuple.
288
+ size_t len = 0 ;
289
+ mp_obj_t * ciphers ;
290
+ mp_obj_get_array (ciphersuite , & len , & ciphers );
291
+ if (len == 0 ) {
292
+ mbedtls_raise_error (MBEDTLS_ERR_SSL_BAD_CONFIG );
293
+ }
294
+
295
+ // Parse list of ciphers.
296
+ ssl_context -> ciphersuites = m_new (int , len + 1 );
297
+ for (int i = 0 , n = len ; i < n ; i ++ ) {
298
+ if (ciphers [i ] != mp_const_none ) {
299
+ const char * ciphername = mp_obj_str_get_str (ciphers [i ]);
300
+ const int id = mbedtls_ssl_get_ciphersuite_id (ciphername );
301
+ ssl_context -> ciphersuites [i ] = id ;
302
+ if (id == 0 ) {
303
+ mbedtls_raise_error (MBEDTLS_ERR_SSL_BAD_CONFIG );
304
+ }
305
+ }
306
+ }
307
+ ssl_context -> ciphersuites [len + 1 ] = 0 ;
308
+
309
+ // Configure ciphersuite.
310
+ mbedtls_ssl_conf_ciphersuites (& ssl_context -> conf , (const int * )ssl_context -> ciphersuites );
311
+
312
+ return mp_const_none ;
313
+ }
314
+ STATIC MP_DEFINE_CONST_FUN_OBJ_2 (ssl_context_set_ciphers_obj , ssl_context_set_ciphers );
315
+
239
316
STATIC void ssl_context_load_key (mp_obj_ssl_context_t * self , mp_obj_t key_obj , mp_obj_t cert_obj ) {
240
317
size_t key_len ;
241
318
const byte * key = (const byte * )mp_obj_str_get_data (key_obj , & key_len );
@@ -264,6 +341,30 @@ STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, m
264
341
}
265
342
}
266
343
344
+ // SSLContext.load_cert_chain(certfile, keyfile)
345
+ STATIC mp_obj_t ssl_context_load_cert_chain (mp_obj_t self_in , mp_obj_t certfile , mp_obj_t keyfile ) {
346
+ mp_obj_ssl_context_t * self = MP_OBJ_TO_PTR (self_in );
347
+ mp_obj_t pkey ;
348
+ mp_obj_t cert ;
349
+ if (certfile != mp_const_none ) {
350
+ // check if key is a string/path
351
+ if (!(mp_obj_is_type (keyfile , & mp_type_bytes ))) {
352
+ pkey = read_file (keyfile );
353
+ } else {
354
+ pkey = keyfile ;
355
+ }
356
+ // check if cert is a string/path
357
+ if (!(mp_obj_is_type (certfile , & mp_type_bytes ))) {
358
+ cert = read_file (certfile );
359
+ } else {
360
+ cert = certfile ;
361
+ }
362
+ ssl_context_load_key (self , pkey , cert );
363
+ }
364
+ return mp_const_none ;
365
+ }
366
+ STATIC MP_DEFINE_CONST_FUN_OBJ_3 (ssl_context_load_cert_chain_obj , ssl_context_load_cert_chain );
367
+
267
368
STATIC void ssl_context_load_cadata (mp_obj_ssl_context_t * self , mp_obj_t cadata_obj ) {
268
369
size_t cacert_len ;
269
370
const byte * cacert = (const byte * )mp_obj_str_get_data (cadata_obj , & cacert_len );
@@ -276,6 +377,30 @@ STATIC void ssl_context_load_cadata(mp_obj_ssl_context_t *self, mp_obj_t cadata_
276
377
mbedtls_ssl_conf_ca_chain (& self -> conf , & self -> cacert , NULL );
277
378
}
278
379
380
+ // SSLContext.load_verify_locations(cafile=None, *, cadata=None)
381
+ STATIC mp_obj_t ssl_context_load_verify_locations (size_t n_args , const mp_obj_t * pos_args ,
382
+ mp_map_t * kw_args ) {
383
+
384
+ static const mp_arg_t allowed_args [] = {
385
+ { MP_QSTR_cafile , MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
386
+ { MP_QSTR_cadata , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
387
+ };
388
+
389
+ mp_obj_ssl_context_t * self = MP_OBJ_TO_PTR (pos_args [0 ]);
390
+ mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
391
+ mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
392
+ // cafile
393
+ if (args [0 ].u_obj != mp_const_none ) {
394
+ ssl_context_load_cadata (self , read_file (args [0 ].u_obj ));
395
+ }
396
+ // cadata
397
+ if (args [1 ].u_obj != mp_const_none ) {
398
+ ssl_context_load_cadata (self , args [1 ].u_obj );
399
+ }
400
+ return mp_const_none ;
401
+ }
402
+ STATIC MP_DEFINE_CONST_FUN_OBJ_KW (ssl_context_load_verify_locations_obj , 1 , ssl_context_load_verify_locations );
403
+
279
404
STATIC mp_obj_t ssl_context_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
280
405
enum { ARG_server_side , ARG_do_handshake_on_connect , ARG_server_hostname };
281
406
static const mp_arg_t allowed_args [] = {
@@ -300,6 +425,10 @@ STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
300
425
#if MICROPY_PY_SSL_FINALISER
301
426
{ MP_ROM_QSTR (MP_QSTR___del__ ), MP_ROM_PTR (& ssl_context___del___obj ) },
302
427
#endif
428
+ { MP_ROM_QSTR (MP_QSTR_get_ciphers ), MP_ROM_PTR (& ssl_context_get_ciphers_obj )},
429
+ { MP_ROM_QSTR (MP_QSTR_set_ciphers ), MP_ROM_PTR (& ssl_context_set_ciphers_obj )},
430
+ { MP_ROM_QSTR (MP_QSTR_load_cert_chain ), MP_ROM_PTR (& ssl_context_load_cert_chain_obj )},
431
+ { MP_ROM_QSTR (MP_QSTR_load_verify_locations ), MP_ROM_PTR (& ssl_context_load_verify_locations_obj )},
303
432
{ MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& ssl_context_wrap_socket_obj ) },
304
433
};
305
434
STATIC MP_DEFINE_CONST_DICT (ssl_context_locals_dict , ssl_context_locals_dict_table );
@@ -369,6 +498,8 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
369
498
o -> last_error = 0 ;
370
499
371
500
int ret ;
501
+ uint32_t flags = 0 ;
502
+
372
503
mbedtls_ssl_init (& o -> ssl );
373
504
374
505
ret = mbedtls_ssl_setup (& o -> ssl , & ssl_context -> conf );
@@ -382,6 +513,11 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
382
513
if (ret != 0 ) {
383
514
goto cleanup ;
384
515
}
516
+ } else if (ssl_context -> authmode == MBEDTLS_SSL_VERIFY_REQUIRED && server_side == false) {
517
+
518
+ o -> sock = MP_OBJ_NULL ;
519
+ mbedtls_ssl_free (& o -> ssl );
520
+ mp_raise_ValueError (MP_ERROR_TEXT ("CERT_REQUIRED requires server_hostname" ));
385
521
}
386
522
387
523
mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
@@ -398,8 +534,23 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
398
534
return MP_OBJ_FROM_PTR (o );
399
535
400
536
cleanup :
537
+ if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED ) {
538
+ flags = mbedtls_ssl_get_verify_result (& o -> ssl );
539
+ }
540
+
401
541
o -> sock = MP_OBJ_NULL ;
402
542
mbedtls_ssl_free (& o -> ssl );
543
+
544
+ if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED ) {
545
+ char xcbuf [256 ];
546
+ int ret_info = mbedtls_x509_crt_verify_info (xcbuf , sizeof (xcbuf ), "\n" , flags );
547
+ // The length of the string written (not including the terminated nul byte),
548
+ // or a negative err code.
549
+ if (ret_info > 0 ) {
550
+ mp_raise_msg_varg (& mp_type_ValueError , MP_ERROR_TEXT ("%s" ), xcbuf );
551
+ }
552
+ }
553
+
403
554
mbedtls_raise_error (ret );
404
555
}
405
556
@@ -416,6 +567,17 @@ STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
416
567
}
417
568
STATIC MP_DEFINE_CONST_FUN_OBJ_2 (mod_ssl_getpeercert_obj , mod_ssl_getpeercert );
418
569
570
+ STATIC mp_obj_t mod_ssl_cipher (mp_obj_t o_in ) {
571
+ mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
572
+ const char * cipher_suite = mbedtls_ssl_get_ciphersuite (& o -> ssl );
573
+ const char * tls_version = mbedtls_ssl_get_version (& o -> ssl );
574
+ mp_obj_t tuple [2 ] = {mp_obj_new_str (cipher_suite , strlen (cipher_suite )),
575
+ mp_obj_new_str (tls_version , strlen (tls_version ))};
576
+
577
+ return mp_obj_new_tuple (2 , tuple );
578
+ }
579
+ STATIC MP_DEFINE_CONST_FUN_OBJ_1 (mod_ssl_cipher_obj , mod_ssl_cipher );
580
+
419
581
STATIC mp_uint_t socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
420
582
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
421
583
o -> poll_mask = 0 ;
@@ -565,6 +727,7 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
565
727
{ MP_ROM_QSTR (MP_QSTR_ioctl ), MP_ROM_PTR (& mp_stream_ioctl_obj ) },
566
728
#endif
567
729
{ MP_ROM_QSTR (MP_QSTR_getpeercert ), MP_ROM_PTR (& mod_ssl_getpeercert_obj ) },
730
+ { MP_ROM_QSTR (MP_QSTR_cipher ), MP_ROM_PTR (& mod_ssl_cipher_obj ) },
568
731
};
569
732
STATIC MP_DEFINE_CONST_DICT (ssl_socket_locals_dict , ssl_socket_locals_dict_table );
570
733
@@ -645,6 +808,7 @@ STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
645
808
{ MP_ROM_QSTR (MP_QSTR_SSLContext ), MP_ROM_PTR (& ssl_context_type ) },
646
809
647
810
// Constants.
811
+ { MP_ROM_QSTR (MP_QSTR_MBEDTLS_VERSION ), MP_ROM_PTR (& mbedtls_version_obj )},
648
812
{ MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (MBEDTLS_SSL_IS_CLIENT ) },
649
813
{ MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (MBEDTLS_SSL_IS_SERVER ) },
650
814
{ MP_ROM_QSTR (MP_QSTR_CERT_NONE ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_NONE ) },
0 commit comments