Skip to content

Commit 15636a2

Browse files
committed
fix(ws_transport): Unit test on reading WS data byte by byte
Closes #14704 Closes espressif/esp-protocols#679
1 parent b8a7d96 commit 15636a2

File tree

2 files changed

+78
-33
lines changed

2 files changed

+78
-33
lines changed

components/tcp_transport/host_test/main/test_websocket_transport.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,52 @@ int mock_write_callback(esp_transport_handle_t transport, const char *request_se
102102
return len;
103103
}
104104

105-
// Callback function for mock_read
105+
// Callbacks for mocked poll_reed and read functions
106+
int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_call)
107+
{
108+
if (num_call) {
109+
return 0;
110+
}
111+
return 1;
112+
}
113+
106114
int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call)
107115
{
116+
if (num_call) {
117+
return 0;
118+
}
108119
std::string websocket_response = make_response();
109120
std::memcpy(buffer, websocket_response.data(), websocket_response.size());
110121
return websocket_response.size();
111122
}
112123

124+
// Callback function for mock_read
125+
int mock_valid_read_fragmented_callback(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, int num_call)
126+
{
127+
static int offset = 0;
128+
std::string websocket_response = make_response();
129+
if (buffer == nullptr) {
130+
return offset == websocket_response.size() ? 0 : 1;
131+
}
132+
int read_size = 1;
133+
if (offset == websocket_response.size()) {
134+
return 0;
135+
}
136+
std::memcpy(buffer, websocket_response.data() + offset, read_size);
137+
offset += read_size;
138+
return read_size;
139+
}
140+
141+
int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeout_ms, int num_call)
142+
{
143+
return mock_valid_read_fragmented_callback(t, nullptr, 0, 0, 0);
144+
}
145+
113146
}
114147

115-
void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read_callback) {
148+
void test_ws_connect(bool expect_valid_connection,
149+
CMOCK_mock_read_CALLBACK read_callback,
150+
CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) {
116151
constexpr static auto timeout = 50;
117152
constexpr static auto port = 8080;
118153
constexpr static auto host = "localhost";
@@ -128,15 +163,15 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read
128163

129164
SECTION("Successful connection and read data") {
130165
fmt::print("Attempting to connect to WebSocket\n");
131-
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
166+
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
132167
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
133168

134169
// Set the callback function for mock_write
135170
mock_write_Stub(mock_write_callback);
136171
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
137172
// Set the callback function for mock_read
138173
mock_read_Stub(read_callback);
139-
mock_poll_read_ExpectAnyArgsAndReturn(1);
174+
mock_poll_read_Stub(poll_read_callback);
140175
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
141176
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
142177

@@ -150,7 +185,11 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read
150185
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
151186

152187
char buffer[WS_BUFFER_SIZE];
153-
int read_len = esp_transport_read(websocket_transport.get(), buffer, WS_BUFFER_SIZE, timeout);
188+
int read_len = 0;
189+
int partial_read;
190+
while ((partial_read = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout)) > 0 ) {
191+
read_len+= partial_read;
192+
}
154193
fmt::print("Read result: {}\n", read_len);
155194
REQUIRE(read_len > 0); // Ensure data is read
156195

@@ -166,6 +205,12 @@ TEST_CASE("WebSocket Transport Connection", "[websocket_transport]")
166205
test_ws_connect(true, mock_valid_read_callback);
167206
}
168207

208+
// Happy flow with fragmented reads byte by byte
209+
TEST_CASE("ws connect and reads by fragments", "[websocket_transport]")
210+
{
211+
test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback);
212+
}
213+
169214
// Some corner cases where we expect the ws connection to fail
170215

171216
TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]")

components/tcp_transport/transport_ws.c

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,34 +133,6 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len
133133
return to_read;
134134
}
135135

136-
static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms)
137-
{
138-
int total_read = 0;
139-
int len = requested_len;
140-
141-
while (len > 0) {
142-
int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms);
143-
144-
if (bytes_read < 0) {
145-
return bytes_read; // Return error from the underlying read
146-
}
147-
148-
if (bytes_read == 0) {
149-
// If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation
150-
ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read);
151-
return -1;
152-
}
153-
154-
// Update buffer and remaining length
155-
buffer += bytes_read;
156-
len -= bytes_read;
157-
total_read += bytes_read;
158-
159-
ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read);
160-
}
161-
return total_read;
162-
}
163-
164136
static char *trimwhitespace(char *str)
165137
{
166138
char *end;
@@ -495,6 +467,34 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int
495467
return rlen;
496468
}
497469

470+
static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms)
471+
{
472+
int total_read = 0;
473+
int len = requested_len;
474+
475+
while (len > 0) {
476+
int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms);
477+
478+
if (bytes_read < 0) {
479+
return bytes_read; // Return error from the underlying read
480+
}
481+
482+
if (bytes_read == 0) {
483+
// If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation
484+
ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read);
485+
return -1;
486+
}
487+
488+
// Update buffer and remaining length
489+
buffer += bytes_read;
490+
len -= bytes_read;
491+
total_read += bytes_read;
492+
493+
ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read);
494+
}
495+
return total_read;
496+
}
497+
498498

499499
/* Read and parse the WS header, determine length of payload */
500500
static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)

0 commit comments

Comments
 (0)