diff --git a/communication/inc/dtls_message_channel.h b/communication/inc/dtls_message_channel.h index 823ff49733..9523d20b77 100644 --- a/communication/inc/dtls_message_channel.h +++ b/communication/inc/dtls_message_channel.h @@ -40,6 +40,8 @@ namespace particle namespace protocol { +class Protocol; + /** * Please centralize this somewhere else! */ @@ -79,7 +81,6 @@ class DTLSMessageChannel: public BufferMessageChannel int (*restore)(void* data, size_t max_length, uint8_t type, void* reserved); uint32_t (*calculate_crc)(const uint8_t* data, uint32_t length); - void (*notify_client_messages_processed)(void* reserved); }; private: @@ -92,6 +93,7 @@ class DTLSMessageChannel: public BufferMessageChannel mbedtls_pk_context pkey; mbedtls_timing_delay_context timer; Callbacks callbacks; + Protocol* protocol; uint8_t* server_public; uint16_t server_public_len; uint32_t keys_checksum; @@ -123,13 +125,14 @@ class DTLSMessageChannel: public BufferMessageChannel void reset_session(); public: - DTLSMessageChannel() : + explicit DTLSMessageChannel(Protocol* protocol) : ssl_context(), conf(), clicert(), pkey(), timer(), callbacks(), + protocol(protocol), server_public(nullptr), server_public_len(0), keys_checksum(0), @@ -167,11 +170,7 @@ class DTLSMessageChannel: public BufferMessageChannel virtual ProtocolError command(Command cmd, void* arg=nullptr) override; - virtual void notify_client_messages_processed() override { - if (callbacks.notify_client_messages_processed) { - callbacks.notify_client_messages_processed(nullptr); - } - } + virtual void notify_client_messages_processed() override; virtual AppStateDescriptor cached_app_state_descriptor() const override; diff --git a/communication/inc/message_channel.h b/communication/inc/message_channel.h index e584b3c226..c63430725e 100644 --- a/communication/inc/message_channel.h +++ b/communication/inc/message_channel.h @@ -240,6 +240,11 @@ struct MessageChannel : public Channel */ virtual void notify_client_messages_processed()=0; + /** + * Returns `true` if there are client messages being processed at the moment, or `false` otherwise. + */ + virtual bool has_pending_client_messages() const = 0; + /** * Get a descriptor of the cached application state. * @@ -262,7 +267,7 @@ struct MessageChannel : public Channel class AbstractMessageChannel : public MessageChannel { public: - void set_debug_enabled(bool enabled) override { + void set_debug_enabled(bool /* enabled */) override { } }; diff --git a/communication/inc/protocol.h b/communication/inc/protocol.h index ced338aac3..72c1d355d1 100644 --- a/communication/inc/protocol.h +++ b/communication/inc/protocol.h @@ -624,10 +624,12 @@ class Protocol virtual int get_describe_data(spark_protocol_describe_data* data, void* reserved); - virtual int get_status(protocol_status* status) const = 0; + int get_status(protocol_status* status) const; void notify_message_complete(message_id_t msg_id, CoAPCode::Enum responseCode); + virtual void notify_client_messages_processed(); // Declared as virtual for mocking in unit tests + /** * Retrieves the next token. */ diff --git a/communication/src/coap_channel.h b/communication/src/coap_channel.h index c30d810bf4..314b2030dd 100644 --- a/communication/src/coap_channel.h +++ b/communication/src/coap_channel.h @@ -25,7 +25,9 @@ #include "service_debug.h" #include "communication_diagnostic.h" + #include +#include namespace particle { @@ -52,8 +54,10 @@ class CoAPChannel : public T } public: - CoAPChannel(message_id_t msg_seed=0) : message_id(msg_seed) - { + template + explicit CoAPChannel(ArgsT&&... args) : + T(std::forward(args)...), + message_id(0) { } /** @@ -554,8 +558,15 @@ class CoAPReliableChannel : public T DelegateChannel delegateChannel; public: + template + explicit CoAPReliableChannel(ArgsT&&... args) : + CoAPReliableChannel(M(), std::forward(args)...) { + } - CoAPReliableChannel(M m=0) : millis(m) { + template + explicit CoAPReliableChannel(M m, ArgsT&&... args) : + T(std::forward(args)...), + millis(m) { delegateChannel.init(this); } @@ -622,6 +633,11 @@ class CoAPReliableChannel : public T return receive(msg, true); } + bool has_pending_client_messages() const override + { + return client.has_messages(); + } + /** * Pulls messages from the message channel */ @@ -637,10 +653,6 @@ class CoAPReliableChannel : public T return client.has_messages() || server.has_unacknowledged_requests(); } - bool has_unacknowledged_client_requests() const { - return client.has_messages(); - } - /** * Pulls messages from the channel and stores it in a message store for * reliable receipt and retransmission. diff --git a/communication/src/description.cpp b/communication/src/description.cpp index 48bf17d6d3..93b5f3abff 100644 --- a/communication/src/description.cpp +++ b/communication/src/description.cpp @@ -393,6 +393,9 @@ ProtocolError Description::receiveAckOrRst(const Message& msg, int* descFlags) { if (!reqQueue_.isEmpty()) { CHECK_PROTOCOL(sendNextRequest(reqQueue_.takeFirst())); } + if (!activeReq_.has_value() && reqQueue_.isEmpty()) { + proto_->notify_client_messages_processed(); + } *descFlags = flags; } } else { diff --git a/communication/src/description.h b/communication/src/description.h index 97ef20e10c..420ea99b15 100644 --- a/communication/src/description.h +++ b/communication/src/description.h @@ -53,6 +53,8 @@ class Description { ProtocolError serialize(Appender* appender, int descFlags); + bool hasPendingClientRequests() const; + void reset(); private: @@ -96,6 +98,10 @@ class Description { system_tick_t millis() const; }; +inline bool Description::hasPendingClientRequests() const { + return activeReq_.has_value() || !reqQueue_.isEmpty(); +} + } // namespace protocol } // namespace particle diff --git a/communication/src/dtls_message_channel.cpp b/communication/src/dtls_message_channel.cpp index 96802b3cae..7504d3f988 100644 --- a/communication/src/dtls_message_channel.cpp +++ b/communication/src/dtls_message_channel.cpp @@ -588,6 +588,10 @@ ProtocolError DTLSMessageChannel::command(Command command, void* arg) return NO_ERROR; } +void DTLSMessageChannel::notify_client_messages_processed() { + protocol->notify_client_messages_processed(); +} + AppStateDescriptor DTLSMessageChannel::cached_app_state_descriptor() const { return sessionPersist.app_state_descriptor(); diff --git a/communication/src/dtls_protocol.cpp b/communication/src/dtls_protocol.cpp index b77e8cccd4..b5ed491d40 100644 --- a/communication/src/dtls_protocol.cpp +++ b/communication/src/dtls_protocol.cpp @@ -26,9 +26,6 @@ void DTLSProtocol::init(const char *id, channelCallbacks.save = callbacks.save; channelCallbacks.restore = callbacks.restore; } - if (offsetof(SparkCallbacks, notify_client_messages_processed) + sizeof(SparkCallbacks::notify_client_messages_processed) <= callbacks.size) { - channelCallbacks.notify_client_messages_processed = callbacks.notify_client_messages_processed; - } // TODO: Ideally, the next token value should be stored in the session data mbedtls_default_rng(nullptr, &next_token, sizeof(next_token)); diff --git a/communication/src/dtls_protocol.h b/communication/src/dtls_protocol.h index 2834bba4d8..178de3b467 100644 --- a/communication/src/dtls_protocol.h +++ b/communication/src/dtls_protocol.h @@ -55,7 +55,11 @@ class DTLSProtocol : public Protocol // todo - this a duplicate of LightSSLProtocol - factor out - DTLSProtocol() : Protocol(channel) {} + DTLSProtocol() : + Protocol(channel), + channel(this), + device_id() { + } void init(const char *id, const SparkKeys &keys, @@ -120,16 +124,6 @@ class DTLSProtocol : public Protocol } } - int get_status(protocol_status* status) const override { - SPARK_ASSERT(status); - status->flags = 0; - if (channel.has_unacknowledged_client_requests()) { - status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES; - } - return NO_ERROR; - } - - /** * Ensures that all outstanding sent coap messages have been acknowledged. */ diff --git a/communication/src/protocol.cpp b/communication/src/protocol.cpp index 3f41a4c84d..acac2c7f2d 100644 --- a/communication/src/protocol.cpp +++ b/communication/src/protocol.cpp @@ -804,6 +804,22 @@ int Protocol::get_describe_data(spark_protocol_describe_data* data, void* reserv return 0; } +int Protocol::get_status(protocol_status* status) const { + SPARK_ASSERT(status); + status->flags = 0; + if (channel.has_pending_client_messages() || description.hasPendingClientRequests()) { + status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES; + } + return ProtocolError::NO_ERROR; +} + +void Protocol::notify_client_messages_processed() { + if (callbacks.notify_client_messages_processed && !channel.has_pending_client_messages() && + !description.hasPendingClientRequests()) { // Ensure there's no pending blockwise requests + callbacks.notify_client_messages_processed(nullptr /* reserved */); + } +} + size_t Protocol::get_max_transmit_message_size() const { if (!max_transmit_message_size) { diff --git a/test/unit_tests/communication/coap_reliability.cpp b/test/unit_tests/communication/coap_reliability.cpp index 9fd85d76ef..03dbb3af32 100644 --- a/test/unit_tests/communication/coap_reliability.cpp +++ b/test/unit_tests/communication/coap_reliability.cpp @@ -1209,7 +1209,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages THEN("the callback is invoked only once") { Verify(Method(channelMock, notify_client_messages_processed)).Once(); - REQUIRE_FALSE(channel.has_unacknowledged_client_requests()); + REQUIRE_FALSE(channel.has_pending_client_messages()); } } @@ -1229,7 +1229,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages THEN("the callback is not invoked") { Verify(Method(channelMock, notify_client_messages_processed)).Never(); - REQUIRE(channel.has_unacknowledged_client_requests()); + REQUIRE(channel.has_pending_client_messages()); } } @@ -1247,7 +1247,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages THEN("the callback is invoked only once") { Verify(Method(channelMock, notify_client_messages_processed)).Once(); - REQUIRE_FALSE(channel.has_unacknowledged_client_requests()); + REQUIRE_FALSE(channel.has_pending_client_messages()); } } } diff --git a/test/unit_tests/communication/description.cpp b/test/unit_tests/communication/description.cpp index 54801d7a46..aa6b0194a2 100644 --- a/test/unit_tests/communication/description.cpp +++ b/test/unit_tests/communication/description.cpp @@ -541,4 +541,39 @@ TEST_CASE("Description") { CHECK(m.option(CoapOption::BLOCK2).toUInt() == BlockOption().index(0).more(true)); d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id())); } + + SECTION("notifies the protocol layer when all client requests have been processed") { + Mock proto(*d.protocol()); + When(Method(proto, notify_client_messages_processed)).AlwaysReturn(); + When(Method(cb, appendSystemInfo)).Do([](appender_fn append, void* arg, void* reserved) { + auto s = std::string(PROTOCOL_BUFFER_SIZE, 'a'); + append(arg, (const uint8_t*)s.data(), s.size()); + return true; + }); + When(Method(cb, appendAppInfo)).Do([](appender_fn append, void* arg, void* reserved) { + auto s = std::string(BLOCK_SIZE, 'b'); + append(arg, (const uint8_t*)s.data(), s.size()); + return true; + }); + CHECK(!d.get()->hasPendingClientRequests()); + // Send a blockwise request to the server + d.get()->sendRequest(DescriptionType::DESCRIBE_SYSTEM); + CHECK(d.get()->hasPendingClientRequests()); + // Receive and acknowledge the first block + auto m = d.receiveMessage(); + d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id())); + // Send another request to the server (a regular one) + d.get()->sendRequest(DescriptionType::DESCRIBE_APPLICATION); + // Receive the second block of the first request + m = d.receiveMessage(); + CHECK(d.get()->hasPendingClientRequests()); + Verify(Method(proto, notify_client_messages_processed)).Never(); + // Acknowledge the second block + d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id())); + CHECK(!d.get()->hasPendingClientRequests()); + Verify(Method(proto, notify_client_messages_processed)).Once(); + // Receive and acknowledge the second request + m = d.receiveMessage(); + d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id())); + } } diff --git a/test/unit_tests/communication/forward_message_channel.h b/test/unit_tests/communication/forward_message_channel.h index 4a7141056a..b321fe9da6 100644 --- a/test/unit_tests/communication/forward_message_channel.h +++ b/test/unit_tests/communication/forward_message_channel.h @@ -78,6 +78,11 @@ class ForwardMessageChannel : public MessageChannel channel->notify_client_messages_processed(); } + virtual bool has_pending_client_messages() const override + { + return channel->has_pending_client_messages(); + } + virtual AppStateDescriptor cached_app_state_descriptor() const override { return AppStateDescriptor(); diff --git a/test/unit_tests/communication/util/coap_message_channel.h b/test/unit_tests/communication/util/coap_message_channel.h index a2a3a4a98d..3f3a3a9d0b 100644 --- a/test/unit_tests/communication/util/coap_message_channel.h +++ b/test/unit_tests/communication/util/coap_message_channel.h @@ -44,7 +44,7 @@ class CoapMessageChannel: public BufferMessageChannel { // Returns true if there's a message received from the device bool hasMessages() const; - // Reimplemented from AbstractMessageChannel + // Reimplemented from MessageChannel ProtocolError send(Message& msg) override; ProtocolError receive(Message& msg) override; ProtocolError command(Command cmd, void* arg) override; @@ -52,6 +52,7 @@ class CoapMessageChannel: public BufferMessageChannel { ProtocolError establish() override; ProtocolError notify_established() override; void notify_client_messages_processed() override; + bool has_pending_client_messages() const override; AppStateDescriptor cached_app_state_descriptor() const override; void reset() override; @@ -112,6 +113,10 @@ inline ProtocolError CoapMessageChannel::notify_established() { inline void CoapMessageChannel::notify_client_messages_processed() { } +inline bool CoapMessageChannel::has_pending_client_messages() const { + return false; +} + inline AppStateDescriptor CoapMessageChannel::cached_app_state_descriptor() const { return AppStateDescriptor(); } diff --git a/test/unit_tests/communication/util/protocol_stub.h b/test/unit_tests/communication/util/protocol_stub.h index 8dfe110974..935b1ed62a 100644 --- a/test/unit_tests/communication/util/protocol_stub.h +++ b/test/unit_tests/communication/util/protocol_stub.h @@ -41,7 +41,6 @@ class ProtocolStub: public Protocol { void init(const char* id, const SparkKeys& keys, const SparkCallbacks& cb, const SparkDescriptor& desc) override; int command(ProtocolCommands::Enum cmd, uint32_t val, const void* data) override; size_t build_hello(Message& msg, uint16_t flags) override; - int get_status(protocol_status* status) const override; private: DescriptorCallbacks desc_; @@ -72,10 +71,6 @@ inline size_t ProtocolStub::build_hello(Message& msg, uint16_t flags) { return 0; } -inline int ProtocolStub::get_status(protocol_status* status) const { - return 0; -} - } // namespace test } // namespace protocol diff --git a/user/tests/integration/communication/functions/functions.cpp b/user/tests/integration/communication/functions/functions.cpp index 0351dd3ab2..d30f29134d 100644 --- a/user/tests/integration/communication/functions/functions.cpp +++ b/user/tests/integration/communication/functions/functions.cpp @@ -72,5 +72,5 @@ test(06_register_many_functions) { } Particle.connect(); waitUntil(Particle.connected); - delay(6000); // Give the system some time to send a blockwise Describe message + delay(3000); } diff --git a/user/tests/integration/communication/variables/variables.cpp b/user/tests/integration/communication/variables/variables.cpp index a9bf9221db..ce94c69ba8 100644 --- a/user/tests/integration/communication/variables/variables.cpp +++ b/user/tests/integration/communication/variables/variables.cpp @@ -133,5 +133,5 @@ test(07_register_many_variables) { } Particle.connect(); waitUntil(Particle.connected); - delay(6000); // Give the system some time to send a blockwise Describe message + delay(3000); }