diff --git a/cores/esp32/USBMSC.cpp b/cores/esp32/USBMSC.cpp index 6d36117b886..eeaf3026535 100644 --- a/cores/esp32/USBMSC.cpp +++ b/cores/esp32/USBMSC.cpp @@ -33,6 +33,7 @@ extern "C" uint16_t tusb_msc_load_descriptor(uint8_t *dst, uint8_t *itf) { typedef struct { bool media_present; + bool is_writable; uint8_t vendor_id[8]; uint8_t product_id[16]; uint8_t product_rev[4]; @@ -179,11 +180,17 @@ int32_t tud_msc_scsi_cb(uint8_t lun, uint8_t const scsi_cmd[16], void *buffer, u return resplen; } +bool tud_msc_is_writable_cb(uint8_t lun) { + log_v("[%u]: %u", lun, msc_luns[lun].is_writable); + return msc_luns[lun].is_writable; // RAM disk is always ready +} + USBMSC::USBMSC() { if (MSC_ACTIVE_LUN < MSC_MAX_LUN) { _lun = MSC_ACTIVE_LUN; MSC_ACTIVE_LUN++; msc_luns[_lun].media_present = false; + msc_luns[_lun].is_writable = true; msc_luns[_lun].vendor_id[0] = 0; msc_luns[_lun].product_id[0] = 0; msc_luns[_lun].product_rev[0] = 0; @@ -213,6 +220,7 @@ bool USBMSC::begin(uint32_t block_count, uint16_t block_size) { void USBMSC::end() { msc_luns[_lun].media_present = false; + msc_luns[_lun].is_writable = false; msc_luns[_lun].vendor_id[0] = 0; msc_luns[_lun].product_id[0] = 0; msc_luns[_lun].product_rev[0] = 0; @@ -247,6 +255,10 @@ void USBMSC::onWrite(msc_write_cb cb) { msc_luns[_lun].write = cb; } +void USBMSC::isWritable(bool is_writable) { + msc_luns[_lun].is_writable = is_writable; +} + void USBMSC::mediaPresent(bool media_present) { msc_luns[_lun].media_present = media_present; } diff --git a/cores/esp32/USBMSC.h b/cores/esp32/USBMSC.h index e9d41e0b7f3..454aca3520a 100644 --- a/cores/esp32/USBMSC.h +++ b/cores/esp32/USBMSC.h @@ -44,6 +44,7 @@ class USBMSC { void productID(const char *pid); //max 16 chars void productRevision(const char *ver); //max 4 chars void mediaPresent(bool media_present); + void isWritable(bool is_writable); void onStartStop(msc_start_stop_cb cb); void onRead(msc_read_cb cb); void onWrite(msc_write_cb cb); diff --git a/cores/esp32/esp32-hal-tinyusb.c b/cores/esp32/esp32-hal-tinyusb.c index d772b3e9e56..1ade6e68020 100644 --- a/cores/esp32/esp32-hal-tinyusb.c +++ b/cores/esp32/esp32-hal-tinyusb.c @@ -450,6 +450,10 @@ __attribute__((weak)) int32_t tud_msc_write10_cb(uint8_t lun, uint32_t lba, uint __attribute__((weak)) int32_t tud_msc_scsi_cb(uint8_t lun, uint8_t const scsi_cmd[16], void *buffer, uint16_t bufsize) { return -1; } +__attribute__((weak)) bool tud_msc_is_writable_cb(uint8_t lun) { + return false; +} + #endif /* diff --git a/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp b/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp index 0ab7168ebab..1dc2e75bbce 100644 --- a/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp +++ b/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp @@ -32,8 +32,11 @@ NetworkClientSecure::NetworkClientSecure() { _connected = false; _timeout = 30000; // Same default as ssl_client - sslclient = new sslclient_context; - ssl_init(sslclient); + sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) { + stop_ssl_socket(sslclient); + delete sslclient; + }); + ssl_init(sslclient.get()); sslclient->socket = -1; sslclient->handshake_timeout = 120000; _use_insecure = false; @@ -53,8 +56,11 @@ NetworkClientSecure::NetworkClientSecure(int sock) { _lastReadTimeout = 0; _lastWriteTimeout = 0; - sslclient = new sslclient_context; - ssl_init(sslclient); + sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) { + stop_ssl_socket(sslclient); + delete sslclient; + }); + ssl_init(sslclient.get()); sslclient->socket = sock; sslclient->handshake_timeout = 120000; @@ -71,20 +77,10 @@ NetworkClientSecure::NetworkClientSecure(int sock) { _alpn_protos = NULL; } -NetworkClientSecure::~NetworkClientSecure() { - stop(); - delete sslclient; -} - -NetworkClientSecure &NetworkClientSecure::operator=(const NetworkClientSecure &other) { - stop(); - sslclient->socket = other.sslclient->socket; - _connected = other._connected; - return *this; -} +NetworkClientSecure::~NetworkClientSecure() {} void NetworkClientSecure::stop() { - stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key); + stop_ssl_socket(sslclient.get()); _connected = false; _peek = -1; @@ -130,10 +126,10 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *CA } int NetworkClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) { - int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); + int ret = start_ssl_client(sslclient.get(), ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); if (ret >= 0 && !_stillinPlainStart) { - ret = ssl_starttls_handshake(sslclient); + ret = ssl_starttls_handshake(sslclient.get()); } else { log_i("Actual TLS start postponed."); } @@ -153,7 +149,7 @@ int NetworkClientSecure::startTLS() { int ret = 1; if (_stillinPlainStart) { log_i("startTLS: starting TLS/SSL on this dplain connection"); - ret = ssl_starttls_handshake(sslclient); + ret = ssl_starttls_handshake(sslclient.get()); if (ret < 0) { log_e("startTLS: %d", ret); stop(); @@ -178,7 +174,7 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *ps return 0; } - int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); + int ret = start_ssl_client(sslclient.get(), address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { log_e("start_ssl_client: connect failed %d", ret); @@ -213,7 +209,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) { } if (_stillinPlainStart) { - return send_net_data(sslclient, buf, size); + return send_net_data(sslclient.get(), buf, size); } if (_lastWriteTimeout != _timeout) { @@ -224,7 +220,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) { _lastWriteTimeout = _timeout; } } - int res = send_ssl_data(sslclient, buf, size); + int res = send_ssl_data(sslclient.get(), buf, size); if (res < 0) { log_e("Closing connection on failed write"); stop(); @@ -235,7 +231,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) { int NetworkClientSecure::read(uint8_t *buf, size_t size) { if (_stillinPlainStart) { - return get_net_receive(sslclient, buf, size); + return get_net_receive(sslclient.get(), buf, size); } if (_lastReadTimeout != _timeout) { @@ -268,7 +264,7 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) { buf++; peeked = 1; } - res = get_ssl_receive(sslclient, buf, size); + res = get_ssl_receive(sslclient.get(), buf, size); if (res < 0) { log_e("Closing connection on failed read"); @@ -280,14 +276,14 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) { int NetworkClientSecure::available() { if (_stillinPlainStart) { - return peek_net_receive(sslclient, 0); + return peek_net_receive(sslclient.get(), 0); } int peeked = (_peek >= 0), res = -1; if (!_connected) { return peeked; } - res = data_to_read(sslclient); + res = data_to_read(sslclient.get()); if (res < 0 && !_stillinPlainStart) { log_e("Closing connection on failed available check"); @@ -346,7 +342,7 @@ bool NetworkClientSecure::verify(const char *fp, const char *domain_name) { return false; } - return verify_ssl_fingerprint(sslclient, fp, domain_name); + return verify_ssl_fingerprint(sslclient.get(), fp, domain_name); } char *NetworkClientSecure::_streamLoad(Stream &stream, size_t size) { diff --git a/libraries/NetworkClientSecure/src/NetworkClientSecure.h b/libraries/NetworkClientSecure/src/NetworkClientSecure.h index 17240820b77..50520e072ef 100644 --- a/libraries/NetworkClientSecure/src/NetworkClientSecure.h +++ b/libraries/NetworkClientSecure/src/NetworkClientSecure.h @@ -24,10 +24,11 @@ #include "IPAddress.h" #include "Network.h" #include "ssl_client.h" +#include class NetworkClientSecure : public NetworkClient { protected: - sslclient_context *sslclient; + std::shared_ptr sslclient; int _lastError = 0; int _peek = -1; @@ -97,14 +98,14 @@ class NetworkClientSecure : public NetworkClient { return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx); }; bool getFingerprintSHA256(uint8_t sha256_result[32]) { - return get_peer_fingerprint(sslclient, sha256_result); + return get_peer_fingerprint(sslclient.get(), sha256_result); }; int fd() const; operator bool() { return connected(); } - NetworkClientSecure &operator=(const NetworkClientSecure &other); + bool operator==(const bool value) { return bool() == value; } diff --git a/libraries/NetworkClientSecure/src/ssl_client.cpp b/libraries/NetworkClientSecure/src/ssl_client.cpp index fd0b8aa4eb8..b33782b71f2 100644 --- a/libraries/NetworkClientSecure/src/ssl_client.cpp +++ b/libraries/NetworkClientSecure/src/ssl_client.cpp @@ -344,7 +344,7 @@ int ssl_starttls_handshake(sslclient_context *ssl_client) { return ssl_client->socket; } -void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) { +void stop_ssl_socket(sslclient_context *ssl_client) { log_v("Cleaning SSL connection."); if (ssl_client->socket >= 0) { diff --git a/libraries/NetworkClientSecure/src/ssl_client.h b/libraries/NetworkClientSecure/src/ssl_client.h index 5690529c112..2309996bb14 100644 --- a/libraries/NetworkClientSecure/src/ssl_client.h +++ b/libraries/NetworkClientSecure/src/ssl_client.h @@ -34,7 +34,7 @@ int start_ssl_client( const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos ); int ssl_starttls_handshake(sslclient_context *ssl_client); -void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); +void stop_ssl_socket(sslclient_context *ssl_client); int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len); int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length); diff --git a/libraries/USB/examples/USBMSC/USBMSC.ino b/libraries/USB/examples/USBMSC/USBMSC.ino index 47c6084580d..532691ba63d 100644 --- a/libraries/USB/examples/USBMSC/USBMSC.ino +++ b/libraries/USB/examples/USBMSC/USBMSC.ino @@ -152,7 +152,10 @@ void setup() { MSC.onStartStop(onStartStop); MSC.onRead(onRead); MSC.onWrite(onWrite); + MSC.mediaPresent(true); + MSC.isWritable(true); // true if writable, false if read-only + MSC.begin(DISK_SECTOR_COUNT, DISK_SECTOR_SIZE); USB.begin(); }