Skip to content

Commit

Permalink
Merge branch 'master' into idf-release/v5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
me-no-dev authored May 13, 2024
2 parents 9b22d50 + ea27a98 commit 4f10e9a
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 32 deletions.
12 changes: 12 additions & 0 deletions cores/esp32/USBMSC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extern "C" uint16_t tusb_msc_load_descriptor(uint8_t *dst, uint8_t *itf) {

typedef struct {
bool media_present;
bool is_writable;
uint8_t vendor_id[8];
uint8_t product_id[16];
uint8_t product_rev[4];
Expand Down Expand Up @@ -179,11 +180,17 @@ int32_t tud_msc_scsi_cb(uint8_t lun, uint8_t const scsi_cmd[16], void *buffer, u
return resplen;
}

bool tud_msc_is_writable_cb(uint8_t lun) {
log_v("[%u]: %u", lun, msc_luns[lun].is_writable);
return msc_luns[lun].is_writable; // RAM disk is always ready
}

USBMSC::USBMSC() {
if (MSC_ACTIVE_LUN < MSC_MAX_LUN) {
_lun = MSC_ACTIVE_LUN;
MSC_ACTIVE_LUN++;
msc_luns[_lun].media_present = false;
msc_luns[_lun].is_writable = true;
msc_luns[_lun].vendor_id[0] = 0;
msc_luns[_lun].product_id[0] = 0;
msc_luns[_lun].product_rev[0] = 0;
Expand Down Expand Up @@ -213,6 +220,7 @@ bool USBMSC::begin(uint32_t block_count, uint16_t block_size) {

void USBMSC::end() {
msc_luns[_lun].media_present = false;
msc_luns[_lun].is_writable = false;
msc_luns[_lun].vendor_id[0] = 0;
msc_luns[_lun].product_id[0] = 0;
msc_luns[_lun].product_rev[0] = 0;
Expand Down Expand Up @@ -247,6 +255,10 @@ void USBMSC::onWrite(msc_write_cb cb) {
msc_luns[_lun].write = cb;
}

void USBMSC::isWritable(bool is_writable) {
msc_luns[_lun].is_writable = is_writable;
}

void USBMSC::mediaPresent(bool media_present) {
msc_luns[_lun].media_present = media_present;
}
Expand Down
1 change: 1 addition & 0 deletions cores/esp32/USBMSC.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class USBMSC {
void productID(const char *pid); //max 16 chars
void productRevision(const char *ver); //max 4 chars
void mediaPresent(bool media_present);
void isWritable(bool is_writable);
void onStartStop(msc_start_stop_cb cb);
void onRead(msc_read_cb cb);
void onWrite(msc_write_cb cb);
Expand Down
4 changes: 4 additions & 0 deletions cores/esp32/esp32-hal-tinyusb.c
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ __attribute__((weak)) int32_t tud_msc_write10_cb(uint8_t lun, uint32_t lba, uint
__attribute__((weak)) int32_t tud_msc_scsi_cb(uint8_t lun, uint8_t const scsi_cmd[16], void *buffer, uint16_t bufsize) {
return -1;
}
__attribute__((weak)) bool tud_msc_is_writable_cb(uint8_t lun) {
return false;
}

#endif

/*
Expand Down
50 changes: 23 additions & 27 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ NetworkClientSecure::NetworkClientSecure() {
_connected = false;
_timeout = 30000; // Same default as ssl_client

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;
});
ssl_init(sslclient.get());
sslclient->socket = -1;
sslclient->handshake_timeout = 120000;
_use_insecure = false;
Expand All @@ -53,8 +56,11 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
_lastReadTimeout = 0;
_lastWriteTimeout = 0;

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;
});
ssl_init(sslclient.get());
sslclient->socket = sock;
sslclient->handshake_timeout = 120000;

Expand All @@ -71,20 +77,10 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
_alpn_protos = NULL;
}

NetworkClientSecure::~NetworkClientSecure() {
stop();
delete sslclient;
}

NetworkClientSecure &NetworkClientSecure::operator=(const NetworkClientSecure &other) {
stop();
sslclient->socket = other.sslclient->socket;
_connected = other._connected;
return *this;
}
NetworkClientSecure::~NetworkClientSecure() {}

void NetworkClientSecure::stop() {
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
stop_ssl_socket(sslclient.get());

_connected = false;
_peek = -1;
Expand Down Expand Up @@ -130,10 +126,10 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *CA
}

int NetworkClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) {
int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);

if (ret >= 0 && !_stillinPlainStart) {
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
} else {
log_i("Actual TLS start postponed.");
}
Expand All @@ -153,7 +149,7 @@ int NetworkClientSecure::startTLS() {
int ret = 1;
if (_stillinPlainStart) {
log_i("startTLS: starting TLS/SSL on this dplain connection");
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
if (ret < 0) {
log_e("startTLS: %d", ret);
stop();
Expand All @@ -178,7 +174,7 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *ps
return 0;
}

int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
_lastError = ret;
if (ret < 0) {
log_e("start_ssl_client: connect failed %d", ret);
Expand Down Expand Up @@ -213,7 +209,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
}

if (_stillinPlainStart) {
return send_net_data(sslclient, buf, size);
return send_net_data(sslclient.get(), buf, size);
}

if (_lastWriteTimeout != _timeout) {
Expand All @@ -224,7 +220,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
_lastWriteTimeout = _timeout;
}
}
int res = send_ssl_data(sslclient, buf, size);
int res = send_ssl_data(sslclient.get(), buf, size);
if (res < 0) {
log_e("Closing connection on failed write");
stop();
Expand All @@ -235,7 +231,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {

int NetworkClientSecure::read(uint8_t *buf, size_t size) {
if (_stillinPlainStart) {
return get_net_receive(sslclient, buf, size);
return get_net_receive(sslclient.get(), buf, size);
}

if (_lastReadTimeout != _timeout) {
Expand Down Expand Up @@ -268,7 +264,7 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {
buf++;
peeked = 1;
}
res = get_ssl_receive(sslclient, buf, size);
res = get_ssl_receive(sslclient.get(), buf, size);

if (res < 0) {
log_e("Closing connection on failed read");
Expand All @@ -280,14 +276,14 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {

int NetworkClientSecure::available() {
if (_stillinPlainStart) {
return peek_net_receive(sslclient, 0);
return peek_net_receive(sslclient.get(), 0);
}

int peeked = (_peek >= 0), res = -1;
if (!_connected) {
return peeked;
}
res = data_to_read(sslclient);
res = data_to_read(sslclient.get());

if (res < 0 && !_stillinPlainStart) {
log_e("Closing connection on failed available check");
Expand Down Expand Up @@ -346,7 +342,7 @@ bool NetworkClientSecure::verify(const char *fp, const char *domain_name) {
return false;
}

return verify_ssl_fingerprint(sslclient, fp, domain_name);
return verify_ssl_fingerprint(sslclient.get(), fp, domain_name);
}

char *NetworkClientSecure::_streamLoad(Stream &stream, size_t size) {
Expand Down
7 changes: 4 additions & 3 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
#include "IPAddress.h"
#include "Network.h"
#include "ssl_client.h"
#include <memory>

class NetworkClientSecure : public NetworkClient {
protected:
sslclient_context *sslclient;
std::shared_ptr<sslclient_context> sslclient;

int _lastError = 0;
int _peek = -1;
Expand Down Expand Up @@ -97,14 +98,14 @@ class NetworkClientSecure : public NetworkClient {
return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx);
};
bool getFingerprintSHA256(uint8_t sha256_result[32]) {
return get_peer_fingerprint(sslclient, sha256_result);
return get_peer_fingerprint(sslclient.get(), sha256_result);
};
int fd() const;

operator bool() {
return connected();
}
NetworkClientSecure &operator=(const NetworkClientSecure &other);

bool operator==(const bool value) {
return bool() == value;
}
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ int ssl_starttls_handshake(sslclient_context *ssl_client) {
return ssl_client->socket;
}

void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) {
void stop_ssl_socket(sslclient_context *ssl_client) {
log_v("Cleaning SSL connection.");

if (ssl_client->socket >= 0) {
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int start_ssl_client(
const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos
);
int ssl_starttls_handshake(sslclient_context *ssl_client);
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
void stop_ssl_socket(sslclient_context *ssl_client);
int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);
Expand Down
3 changes: 3 additions & 0 deletions libraries/USB/examples/USBMSC/USBMSC.ino
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ void setup() {
MSC.onStartStop(onStartStop);
MSC.onRead(onRead);
MSC.onWrite(onWrite);

MSC.mediaPresent(true);
MSC.isWritable(true); // true if writable, false if read-only

MSC.begin(DISK_SECTOR_COUNT, DISK_SECTOR_SIZE);
USB.begin();
}
Expand Down

0 comments on commit 4f10e9a

Please sign in to comment.