diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml index 6e2ab7285f..74d61608a7 100644 --- a/.github/workflows/compile.yml +++ b/.github/workflows/compile.yml @@ -28,6 +28,13 @@ jobs: matrix: presets: [ gcc-debug, gcc-release, clang-debug, clang-release, gcc-debug-dynamic, clang-debug-dynamic ] + # permissions to publish test results + permissions: + contents: read + issues: read + checks: write + pull-requests: write + steps: - uses: actions/checkout@v4 with: @@ -58,3 +65,10 @@ jobs: configurePreset: '${{ matrix.presets }}' buildPreset: '${{ matrix.presets }}' testPreset: '${{ matrix.presets }}' + + - name: Publish Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + files: | + */**/test-results.xml diff --git a/cmake b/cmake index 63de35dc95..e6760d6d2a 160000 --- a/cmake +++ b/cmake @@ -1 +1 @@ -Subproject commit 63de35dc95ad446f2a27141141920e75d0f3dd66 +Subproject commit e6760d6d2a78a0e25c7e778c75196ad96533c667 diff --git a/exes/CMakeLists.txt b/exes/CMakeLists.txt index 91d12de1f2..cc3d27c30b 100644 --- a/exes/CMakeLists.txt +++ b/exes/CMakeLists.txt @@ -6,4 +6,4 @@ add_subdirectory(tfcctl) add_subdirectory(ipc-ruler) add_subdirectory(signal_source) add_subdirectory(mqtt-bridge) -add_subdirectory(themis) \ No newline at end of file +add_subdirectory(themis) diff --git a/exes/ethercat/CMakeLists.txt b/exes/ethercat/CMakeLists.txt index c0c8a62d37..1fe8eef5e4 100644 --- a/exes/ethercat/CMakeLists.txt +++ b/exes/ethercat/CMakeLists.txt @@ -25,7 +25,7 @@ target_link_libraries(ec tfc::soem_interface tfc::stx tfc::motor - mp-units::si + mp-units::systems fmt::fmt ) diff --git a/exes/ethercat/inc/public/tfc/ec/devices/base.hpp b/exes/ethercat/inc/public/tfc/ec/devices/base.hpp index 4b5b1bdff3..304ae2a5a2 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/base.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/base.hpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/beckhoff/EL3xxx.hpp b/exes/ethercat/inc/public/tfc/ec/devices/beckhoff/EL3xxx.hpp index 209b9546f9..6d2c5892cd 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/beckhoff/EL3xxx.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/beckhoff/EL3xxx.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/eilersen/4x60a.hpp b/exes/ethercat/inc/public/tfc/ec/devices/eilersen/4x60a.hpp index fa79083fb4..501686d8b1 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/eilersen/4x60a.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/eilersen/4x60a.hpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320.hpp b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320.hpp index 4fd8242132..1a32a6afa3 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320.hpp @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/dbus-iface.hpp b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/dbus-iface.hpp index eaee490abd..9aa9f9a6eb 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/dbus-iface.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/dbus-iface.hpp @@ -7,9 +7,9 @@ #include -#include #include -#include +#include +#include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/settings.hpp b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/settings.hpp index ceb192efee..73a069d37c 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/settings.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/settings.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/speedratio.hpp b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/speedratio.hpp index b64e450a9a..5ff230d74a 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/speedratio.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/schneider/atv320/speedratio.hpp @@ -1,6 +1,6 @@ #pragma once #include -#include +#include namespace tfc::ec::devices::schneider::atv320::detail { diff --git a/exes/ethercat/inc/public/tfc/ec/devices/util.hpp b/exes/ethercat/inc/public/tfc/ec/devices/util.hpp index d0ba723be1..011f456bee 100644 --- a/exes/ethercat/inc/public/tfc/ec/devices/util.hpp +++ b/exes/ethercat/inc/public/tfc/ec/devices/util.hpp @@ -4,8 +4,8 @@ #include -#include -#include +#include +#include #include #include #include diff --git a/exes/ethercat/tests/test_ec_util.cpp b/exes/ethercat/tests/test_ec_util.cpp index 060a88415c..db06c51657 100644 --- a/exes/ethercat/tests/test_ec_util.cpp +++ b/exes/ethercat/tests/test_ec_util.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include #include @@ -49,8 +49,8 @@ auto main(int, char**) -> int { "setting to json"_test = []() { [[maybe_unused]] example::trivial_type_setting const test{}; auto const json = glz::write_json(test); - expect(json == "13") << "got: " << json; - auto const exp = glz::read_json(json); + expect(json == "13") << "got: " << json.value_or(""); + auto const exp = glz::read_json(json.value_or("")); expect(exp.has_value() >> fatal); expect(exp.value() == test); }; diff --git a/exes/gpio/src/gpio.cpp b/exes/gpio/src/gpio.cpp index 9b0bca3197..0da7ed5750 100644 --- a/exes/gpio/src/gpio.cpp +++ b/exes/gpio/src/gpio.cpp @@ -16,8 +16,8 @@ void gpio::pin_direction_change(pin_index_t idx, gpiod::line::direction new_value, [[maybe_unused]] gpiod::line::direction old_value) noexcept { try { - logger_.trace(R"(Got new direction change with new value: "{}", old value: "{}")", glz::write_json(new_value), - glz::write_json(old_value)); + logger_.trace(R"(Got new direction change with new value: "{}", old value: "{}")", + glz::write_json(new_value).value_or(""), glz::write_json(old_value).value_or("")); if (new_value == gpiod::line::direction::OUTPUT) { pins_.at(idx).emplace(ctx_, manager_client_, fmt::format("in.{}", idx), std::bind_front(&gpio::ipc_event, this, idx)); @@ -46,8 +46,8 @@ void gpio::pin_direction_change(pin_index_t idx, void gpio::pin_edge_change(pin_index_t idx, gpiod::line::edge new_value, [[maybe_unused]] gpiod::line::edge old_value) noexcept { - logger_.trace(R"(Got new edge change with new value: "{}", old value: "{}")", glz::write_json(new_value), - glz::write_json(old_value)); + logger_.trace(R"(Got new edge change with new value: "{}", old value: "{}")", glz::write_json(new_value).value_or(""), + glz::write_json(old_value).value_or("")); auto settings{ chip_.prepare_request().get_line_config().get_line_settings().at(idx) }; settings.set_edge_detection(new_value); chip_.prepare_request().add_line_settings(idx, settings); @@ -60,8 +60,8 @@ void gpio::pin_edge_change(pin_index_t idx, void gpio::pin_bias_change(pin_index_t idx, gpiod::line::bias new_value, [[maybe_unused]] gpiod::line::bias old_value) noexcept { - logger_.trace(R"(Got new bias change with new value: "{}", old value: "{}")", glz::write_json(new_value), - glz::write_json(old_value)); + logger_.trace(R"(Got new bias change with new value: "{}", old value: "{}")", glz::write_json(new_value).value_or(""), + glz::write_json(old_value).value_or("")); auto settings{ chip_.prepare_request().get_line_config().get_line_settings().at(idx) }; settings.set_bias(new_value); chip_.prepare_request().add_line_settings(idx, settings).do_request(); @@ -69,8 +69,8 @@ void gpio::pin_bias_change(pin_index_t idx, void gpio::pin_force_change(pin_index_t idx, pin::out::force_e new_value, [[maybe_unused]] pin::out::force_e old_value) noexcept { - logger_.trace(R"(Got new force change with new value: "{}", old value: "{}")", glz::write_json(new_value), - glz::write_json(old_value)); + logger_.trace(R"(Got new force change with new value: "{}", old value: "{}")", glz::write_json(new_value).value_or(""), + glz::write_json(old_value).value_or("")); auto settings{ chip_.prepare_request().get_line_config().get_line_settings().at(idx) }; switch (new_value) { using enum pin::out::force_e; @@ -104,8 +104,8 @@ void gpio::pin_force_change(pin_index_t idx, void gpio::pin_drive_change(pin_index_t idx, gpiod::line::drive new_value, [[maybe_unused]] gpiod::line::drive old_value) noexcept { - logger_.trace(R"(Got new drive change with new value: "{}", old value: "{}")", glz::write_json(new_value), - glz::write_json(old_value)); + logger_.trace(R"(Got new drive change with new value: "{}", old value: "{}")", glz::write_json(new_value).value_or(""), + glz::write_json(old_value).value_or("")); auto settings{ chip_.prepare_request().get_line_config().get_line_settings().at(idx) }; settings.set_drive(new_value); chip_.prepare_request().add_line_settings(idx, settings).do_request(); diff --git a/exes/mqtt-bridge/inc/client.hpp b/exes/mqtt-bridge/inc/client.hpp index 22f836d0a4..43c38cfe7c 100644 --- a/exes/mqtt-bridge/inc/client.hpp +++ b/exes/mqtt-bridge/inc/client.hpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include #include @@ -67,7 +67,7 @@ class client { logger_.trace("Sending MQTT connection packet..."); - auto send_error = co_await endpoint_client_->send(connect_packet(), asio::use_awaitable); + auto [send_error] = co_await endpoint_client_->send(connect_packet(), as_tuple(asio::use_awaitable)); if (send_error) { co_return false; @@ -102,10 +102,10 @@ class client { p_id = 0; } - auto pub_packet = async_mqtt::v5::publish_packet{ p_id.value(), async_mqtt::allocate_buffer(topic), - async_mqtt::allocate_buffer(payload), qos }; + auto pub_packet = async_mqtt::v5::publish_packet{ p_id.value(), topic, payload, qos }; - co_return !(co_await endpoint_client_->send(pub_packet, asio::use_awaitable)); + auto [error] = co_await endpoint_client_->send(pub_packet, as_tuple(asio::use_awaitable)); + co_return !error; } auto subscribe_to_topic(std::string topic) -> asio::awaitable { @@ -116,13 +116,11 @@ class client { p_id = endpoint_client_->acquire_unique_packet_id(); - std::string_view topic_view{ topic }; + auto sub_packet = + async_mqtt::v5::subscribe_packet{ p_id.value(), + { { topic, async_mqtt::qos::at_most_once | async_mqtt::sub::nl::yes } } }; - auto sub_packet = async_mqtt::v5::subscribe_packet{ - p_id.value(), { { async_mqtt::buffer(topic_view), async_mqtt::qos::at_most_once | async_mqtt::sub::nl::yes } } - }; - - auto send_error = co_await endpoint_client_->send(sub_packet, asio::use_awaitable); + auto [send_error] = co_await endpoint_client_->send(sub_packet, as_tuple(asio::use_awaitable)); if (send_error) { co_return !send_error; @@ -145,7 +143,7 @@ class client { for (auto const& entry : suback_packet.entries()) { if (entry != async_mqtt::suback_reason_code::granted_qos_0) { logger_.error("Error subscribing to topic: {}, reason code: {}", topic.data(), - async_mqtt::suback_reason_code_to_str(entry)); + async_mqtt::suback_reason_code_to_string(entry)); co_return false; } } @@ -173,13 +171,13 @@ class client { logger_.trace("Received PUBLISH packet. Parsing payload..."); - for (uint64_t i = 0; i < publish_packet->payload().size(); i++) { - process_payload(publish_packet->payload()[i], *publish_packet); + for (auto const& itm : publish_packet->payload_as_buffer()) { + process_payload(itm, *publish_packet); } } } - auto strand() -> asio::strand { return endpoint_client_->strand(); } + auto get_executor() const -> asio::any_io_executor { return endpoint_client_->get_executor(); } auto set_initial_message(std::string const& topic, std::string const& payload, async_mqtt::qos const& qos) -> void { initial_message_ = std::tuple{ topic, payload, qos }; @@ -195,26 +193,25 @@ class client { auto connect_packet() -> async_mqtt::v5::connect_packet { if (config_.value().username.empty() || config_.value().password.empty()) { - return async_mqtt::v5::connect_packet{ true, - std::chrono::seconds(100).count(), - async_mqtt::allocate_buffer(config_.value().client_id), - async_mqtt::will( - async_mqtt::allocate_buffer(mqtt_will_topic_), - async_mqtt::buffer(std::string_view{ mqtt_will_payload_ }), - { async_mqtt::qos::at_least_once | async_mqtt::pub::retain::no }), - async_mqtt::nullopt, - async_mqtt::nullopt, - { async_mqtt::property::session_expiry_interval{ 0 } } }; + return async_mqtt::v5::connect_packet{ + true, + std::chrono::seconds(100).count(), + config_.value().client_id, + async_mqtt::will(mqtt_will_topic_, async_mqtt::buffer(std::string_view{ mqtt_will_payload_ }), + { async_mqtt::qos::at_least_once | async_mqtt::pub::retain::no }), + std::nullopt, + std::nullopt, + { async_mqtt::property::session_expiry_interval{ 0 } } + }; } return async_mqtt::v5::connect_packet{ true, std::chrono::seconds(100).count(), - async_mqtt::allocate_buffer(config_.value().client_id), + config_.value().client_id, async_mqtt::will( - async_mqtt::allocate_buffer(mqtt_will_topic_), - async_mqtt::buffer(std::string_view{ mqtt_will_payload_ }), + mqtt_will_topic_, async_mqtt::buffer(std::string_view{ mqtt_will_payload_ }), { async_mqtt::qos::at_least_once | async_mqtt::pub::retain::no }), - async_mqtt::allocate_buffer(config_.value().username), - async_mqtt::allocate_buffer(config_.value().password), + config_.value().username, + config_.value().password, { async_mqtt::property::session_expiry_interval{ 0 } } }; } diff --git a/exes/mqtt-bridge/inc/endpoint.hpp b/exes/mqtt-bridge/inc/endpoint.hpp index 2daa31bd89..a6a3dd09cf 100644 --- a/exes/mqtt-bridge/inc/endpoint.hpp +++ b/exes/mqtt-bridge/inc/endpoint.hpp @@ -6,9 +6,11 @@ #include #include +#include #include #include -#include +#include +#include #include #include @@ -35,34 +37,34 @@ class endpoint_client { } } - auto strand() -> asio::strand { + auto get_executor() const -> asio::any_io_executor { if (mqtts_client_) { - return mqtts_client_->strand(); + return mqtts_client_->get_executor(); } - return mqtt_client_->strand(); + return mqtt_client_->get_executor(); } auto recv(async_mqtt::control_packet_type packet_t) { if (mqtts_client_) { - return mqtts_client_->recv(async_mqtt::filter::match, { packet_t }, asio::use_awaitable); + return mqtts_client_->async_recv(async_mqtt::filter::match, { packet_t }, asio::use_awaitable); } - return mqtt_client_->recv(async_mqtt::filter::match, { packet_t }, asio::use_awaitable); + return mqtt_client_->async_recv(async_mqtt::filter::match, { packet_t }, asio::use_awaitable); } template auto send(args_t&&... args) { if (mqtts_client_) { - return mqtts_client_->send(std::forward(args)...); + return mqtts_client_->async_send(std::forward(args)...); } - return mqtt_client_->send(std::forward(args)...); + return mqtt_client_->async_send(std::forward(args)...); } template auto close(args_t&&... args) { if (mqtts_client_) { - return mqtts_client_->close(std::forward(args)...); + return mqtts_client_->async_close(std::forward(args)...); } - return mqtt_client_->close(std::forward(args)...); + return mqtt_client_->async_close(std::forward(args)...); } auto acquire_unique_packet_id() { diff --git a/exes/mqtt-bridge/inc/run.hpp b/exes/mqtt-bridge/inc/run.hpp index 631a179ea3..119a468b9b 100644 --- a/exes/mqtt-bridge/inc/run.hpp +++ b/exes/mqtt-bridge/inc/run.hpp @@ -55,14 +55,15 @@ class run { bool restart_needed = false; co_spawn( - sp_interface_.strand(), + sp_interface_.get_executor(), sp_interface_.wait_for_payloads(std::bind_front(&spark_plug::process_payload, &sp_interface_), restart_needed), bind_cancellation_slot(cancel_signal.slot(), asio::detached)); io_ctx_.run_for(std::chrono::seconds{ 1 }); while (!restart_needed) { - co_await asio::steady_timer{ sp_interface_.strand(), std::chrono::seconds{ 5 } }.async_wait(asio::use_awaitable); + co_await asio::steady_timer{ sp_interface_.get_executor(), std::chrono::seconds{ 5 } }.async_wait( + asio::use_awaitable); } cancel_signal.emit(asio::cancellation_type::all); diff --git a/exes/mqtt-bridge/inc/spark_plug_interface.hpp b/exes/mqtt-bridge/inc/spark_plug_interface.hpp index 3a1d5f23a9..73a0092bcb 100644 --- a/exes/mqtt-bridge/inc/spark_plug_interface.hpp +++ b/exes/mqtt-bridge/inc/spark_plug_interface.hpp @@ -123,7 +123,7 @@ class spark_plug_interface { mqtt_client_->set_initial_message(topic, payload_string, async_mqtt::qos::at_most_once); - asio::co_spawn(mqtt_client_->strand(), + asio::co_spawn(mqtt_client_->get_executor(), mqtt_client_->send_message(topic, payload_string, async_mqtt::qos::at_most_once), asio::detached); } } @@ -141,7 +141,7 @@ class spark_plug_interface { } auto update_value(structs::spark_plug_b_variable& variable) -> void { - asio::co_spawn(strand(), update_value_impl(variable), asio::detached); + asio::co_spawn(get_executor(), update_value_impl(variable), asio::detached); } auto update_value_impl(structs::spark_plug_b_variable& variable) -> asio::awaitable { @@ -245,7 +245,7 @@ class spark_plug_interface { } } - auto strand() -> asio::strand { return mqtt_client_->strand(); } + auto get_executor() const -> asio::any_io_executor { return mqtt_client_->get_executor(); } static auto set_value_payload(Payload_Metric* metric, std::optional const& value, logger::logger const& logger) -> void { diff --git a/exes/mqtt-bridge/tests/inc/broker/broker.hpp b/exes/mqtt-bridge/tests/inc/broker/broker.hpp new file mode 100644 index 0000000000..3a55a2c30a --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/broker.hpp @@ -0,0 +1,1413 @@ +#pragma once +// Copyright Takatoshi Kondo 2022 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_BROKER_HPP) +#define ASYNC_MQTT_BROKER_BROKER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace async_mqtt { + +template +class broker { + using epsp_type = epsp_wrap; + using this_type = broker; + +public: + broker(as::io_context& timer_ioc, bool recycling_allocator = false) + : timer_ioc_{ timer_ioc }, tim_disconnect_{ timer_ioc_ }, recycling_allocator_{ recycling_allocator } { + std::unique_lock g_sec{ mtx_security_ }; + security_.default_config(); + } + + void handle_accept(epsp_type epsp, std::optional preauthed_user_name = {}) { + epsp.set_preauthed_user_name(force_move(preauthed_user_name)); + async_read_packet(force_move(epsp)); + } + + /** + * @brief configure the security settings + */ + void set_security(security&& sec) { + std::unique_lock g_sec{ mtx_security_ }; + security_ = force_move(sec); + } + +private: + void async_read_packet(epsp_type epsp) { + auto recv_proc = [this, epsp](error_code const& ec, packet_variant pv) mutable { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + close_proc(force_move(epsp), + true // send_will + ); + return; + } + BOOST_ASSERT(pv); + pv.visit(overload{ + [&](v3_1_1::connect_packet& p) { + connect_handler(force_move(epsp), p.client_id(), p.user_name(), p.password(), p.get_will(), p.clean_session(), + p.keep_alive(), properties{}); + }, + [&](v5::connect_packet& p) { + connect_handler(force_move(epsp), p.client_id(), p.user_name(), p.password(), p.get_will(), p.clean_start(), + p.keep_alive(), p.props()); + }, + [&](v3_1_1::publish_packet& p) { + publish_handler(force_move(epsp), p.packet_id(), p.opts(), p.topic(), p.payload_as_buffer(), properties{}); + }, + [&](v5::publish_packet& p) { + publish_handler(force_move(epsp), p.packet_id(), p.opts(), p.topic(), p.payload_as_buffer(), p.props()); + }, + [&](v3_1_1::puback_packet& p) { + puback_handler(force_move(epsp), p.packet_id(), puback_reason_code::success, properties{}); + }, + [&](v5::puback_packet& p) { puback_handler(force_move(epsp), p.packet_id(), p.code(), p.props()); }, + [&](v3_1_1::pubrec_packet& p) { + pubrec_handler(force_move(epsp), p.packet_id(), pubrec_reason_code::success, properties{}); + }, + [&](v5::pubrec_packet& p) { pubrec_handler(force_move(epsp), p.packet_id(), p.code(), p.props()); }, + [&](v3_1_1::pubrel_packet& p) { + pubrel_handler(force_move(epsp), p.packet_id(), pubrel_reason_code::success, properties{}); + }, + [&](v5::pubrel_packet& p) { pubrel_handler(force_move(epsp), p.packet_id(), p.code(), p.props()); }, + [&](v3_1_1::pubcomp_packet& p) { + pubcomp_handler(force_move(epsp), p.packet_id(), pubcomp_reason_code::success, properties{}); + }, + [&](v5::pubcomp_packet& p) { pubcomp_handler(force_move(epsp), p.packet_id(), p.code(), p.props()); }, + [&](v3_1_1::subscribe_packet& p) { + subscribe_handler(force_move(epsp), p.packet_id(), p.entries(), properties{}); + }, + [&](v5::subscribe_packet& p) { subscribe_handler(force_move(epsp), p.packet_id(), p.entries(), p.props()); }, + [&](v3_1_1::suback_packet&) { + // TBD receive invalid packet + }, + [&](v5::suback_packet&) { + // TBD receive invalid packet + }, + [&](v3_1_1::unsubscribe_packet& p) { + unsubscribe_handler(force_move(epsp), p.packet_id(), p.entries(), properties{}); + }, + [&](v5::unsubscribe_packet& p) { unsubscribe_handler(force_move(epsp), p.packet_id(), p.entries(), p.props()); }, + [&](v3_1_1::pingreq_packet&) { pingreq_handler(force_move(epsp)); }, + [&](v5::pingreq_packet&) { pingreq_handler(force_move(epsp)); }, + [&](v3_1_1::disconnect_packet&) { + disconnect_handler(force_move(epsp), disconnect_reason_code::normal_disconnection, properties{}); + }, + [&](v5::disconnect_packet& p) { disconnect_handler(force_move(epsp), p.code(), p.props()); }, + [&](v5::auth_packet& p) { auth_handler(force_move(epsp), p.code(), p.props()); }, + [&](auto const&) { + ASYNC_MQTT_LOG("mqtt_broker", fatal) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "invalid variant"; + } }); + }; + + if (recycling_allocator_) { + epsp.async_recv(as::bind_allocator(as::recycling_allocator(), recv_proc)); + } else { + epsp.async_recv(recv_proc); + } + } + + void connect_handler(epsp_type epsp, + std::string client_id, + std::optional noauth_username, + std::optional password, + std::optional will, + bool clean_start, + std::uint16_t /*keep_alive*/, + properties props) { + std::optional username; + if (auto paun_opt = epsp.get_preauthed_user_name()) { + std::shared_lock g_sec{ mtx_security_ }; + if (security_.login_cert(*paun_opt)) { + username = force_move(*paun_opt); + } + } else if (!noauth_username && !password) { + std::shared_lock g_sec{ mtx_security_ }; + username = security_.login_anonymous(); + } else if (noauth_username && password) { + std::shared_lock g_sec{ mtx_security_ }; + username = security_.login(*noauth_username, *password); + } + + // If login fails, try the unauthenticated user + if (!username) { + std::shared_lock g_sec{ mtx_security_ }; + username = security_.login_unauthenticated(); + } + + std::optional session_expiry_interval; + std::optional will_expiry_interval; + bool response_topic_requested = false; + + auto version = epsp.get_protocol_version(); + if (version == protocol_version::v5) { + for (auto const& prop : props) { + prop.visit(overload{ [&](property::session_expiry_interval const& v) { + if (v.val() != 0) { + session_expiry_interval.emplace(std::chrono::seconds(v.val())); + } + }, + [&](property::request_response_information const& v) { response_topic_requested = v.val(); }, + [&](auto const&) {} }); + } + if (will) { + for (auto const& prop : will->props()) { + prop.visit(overload{ [&](property::message_expiry_interval const& v) { + will_expiry_interval.emplace(std::chrono::seconds(v.val())); + }, + [&](auto const&) {} }); + } + } + + // for test + if (h_connect_props_) { + h_connect_props_(props); + } + } + + properties connack_props; + + if (!username) { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) + << "User failed to login: " << (noauth_username ? std::string(*noauth_username) : std::string("anonymous user")); + + send_connack(epsp, + false, // session present + false, // authenticated + force_move(connack_props), [epsp, version](error_code) mutable { + disconnect_and_close(epsp, version, disconnect_reason_code::not_authorized, as::detached); + }); + return; + } + + if (client_id.empty()) { + if (!handle_empty_client_id(epsp, client_id, clean_start, connack_props)) { + return; + } + // A new client id was generated + client_id = epsp.get_client_id(); + } + + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "User logged in as: '" + << *username << "', client_id: " << client_id; + + /** + * http://docs.oasis-open.org/mqtt/mqtt/v5.0/cs02/mqtt-v5.0-cs02.html#_Toc514345311 + * 3.1.2.4 Clean Start + * If a CONNECT packet is received with Clean Start is set to 1, the Client and Server MUST + * discard any existing Session and start a new Session [MQTT-3.1.2-4]. Consequently, + * the Session Present flag in CONNACK is always set to 0 if Clean Start is set to 1. + */ + + // Find any sessions that have the same client_id + std::lock_guard g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.lower_bound(std::make_tuple(*username, client_id)); + if (it == idx.end() || (*it)->client_id() != client_id || (*it)->get_username() != *username) { + // new connection + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "cid:" << client_id << " new connection inserted."; + it = idx.emplace_hint( + it, session_state::create( + timer_ioc_, mtx_subs_map_, subs_map_, shared_targets_, epsp, client_id, *username, force_move(will), + // will_sender + [this](auto&&... params) { this->do_publish(std::forward(params)...); }, clean_start, + force_move(will_expiry_interval), force_move(session_expiry_interval))); + if (response_topic_requested) { + // set_response_topic never modify key part + set_response_topic(const_cast&>(**it), connack_props, *username); + } + + send_connack(epsp, + false, // session present + true, // authenticated + force_move(connack_props), [this, epsp](error_code) mutable { async_read_packet(force_move(epsp)); }); + } else if (auto old_epsp = const_cast&>(**it).lock()) { + // online overwrite + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "cid:" << client_id << " old connection " + << old_epsp.get_address() << " exists and is online. close it "; + close_proc_no_lock( + old_epsp, true, disconnect_reason_code::session_taken_over, + [this, epsp, &idx, it, connack_props = force_move(connack_props), clean_start, client_id = force_move(client_id), + response_topic_requested, username = force_move(username), will = force_move(will), will_expiry_interval, + session_expiry_interval](bool remain_as_offline) mutable { + if (remain_as_offline) { + // offline exists -> online + offline_to_online(force_move(epsp), force_move(will), force_move(will_expiry_interval), + force_move(session_expiry_interval), clean_start, force_move(*username), idx, it, + response_topic_requested, force_move(connack_props)); + } else { + // new connection + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, this) << "cid:" << client_id + << "online connection exists, discard old one due to session_expiry and renew"; + bool inserted; + std::tie(it, inserted) = idx.emplace(session_state::create( + timer_ioc_, mtx_subs_map_, subs_map_, shared_targets_, epsp, client_id, *username, force_move(will), + // will_sender + [this](auto&&... params) { this->do_publish(std::forward(params)...); }, clean_start, + force_move(will_expiry_interval), force_move(session_expiry_interval))); + BOOST_ASSERT(inserted); + if (response_topic_requested) { + // set_response_topic never modify key part + set_response_topic(const_cast&>(**it), connack_props, *username); + } + send_connack(epsp, + false, // session present + true, // authenticated + force_move(connack_props), + [this, epsp](error_code) mutable { async_read_packet(force_move(epsp)); }); + } + }); + } else { + // offline exists -> online + offline_to_online(force_move(epsp), force_move(will), force_move(will_expiry_interval), + force_move(session_expiry_interval), clean_start, force_move(*username), idx, it, + response_topic_requested, force_move(connack_props)); + } + } + + template + void offline_to_online(epsp_type epsp, + std::optional will, + std::optional will_expiry_interval, + std::optional session_expiry_interval, + bool clean_start, + std::string username, + Idx& idx, + It it, + bool response_topic_requested, + properties connack_props) { + if (clean_start) { + // discard offline session + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, this) << "un:" << username + << "offline connection exists, discard old one due to new one's clean_start and renew"; + idx.modify( + it, [&](auto& e) { e->renew(epsp, will, clean_start, will_expiry_interval, session_expiry_interval); }, + [](auto&) { BOOST_ASSERT(false); }); + if (response_topic_requested) { + // set_response_topic never modify key part + set_response_topic(const_cast&>(**it), connack_props, username); + } + send_connack(epsp, + false, // session present + true, // authenticated + force_move(connack_props), [this, epsp](error_code) mutable { async_read_packet(force_move(epsp)); }); + } else { + // inherit online session if previous session's session exists + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, this) << "un:" << username << "offline connection exists, and inherit it"; + if (response_topic_requested) { + // set_response_topic never modify key part + set_response_topic(const_cast&>(**it), connack_props, username); + } + + epsp.dispatch([this, epsp, &idx, it, username = force_move(username), will = force_move(will), will_expiry_interval, + session_expiry_interval, connack_props = force_move(connack_props)]() mutable { + idx.modify( + it, + [&](auto& e) { e->inherit(epsp, force_move(will), will_expiry_interval, force_move(session_expiry_interval)); }, + [](auto&) { BOOST_ASSERT(false); }); + send_connack(epsp, + true, // session present + true, // authenticated + force_move(connack_props), [this, epsp, &idx, it](error_code const& ec) mutable { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << ec.message(); + return; + } + idx.modify( + it, + [&](auto& e) { + e->send_inflight_messages(); + e->send_all_offline_messages(); + }, + [](auto&) { BOOST_ASSERT(false); }); + async_read_packet(force_move(epsp)); + }); + }); + } + } + + void set_response_topic(session_state& s, properties& connack_props, std::string const& username) { + auto response_topic = [&] { + if (auto rt_opt = s.get_response_topic()) { + return *rt_opt; + } + auto rt = create_uuid_string(); + s.set_response_topic(rt); + return rt; + }(); + + auto rule_nr = [&] { + std::unique_lock g_sec{ mtx_security_ }; + return security_.add_auth(response_topic, { "@any" }, security::authorization::type::allow, { username }, + security::authorization::type::allow); + }(); + + s.set_clean_handler([this, response_topic, rule_nr]() { + std::lock_guard g(mtx_retains_); + retains_.erase(response_topic); + std::unique_lock g_sec{ mtx_security_ }; + security_.remove_auth(rule_nr); + }); + + connack_props.emplace_back(property::response_information{ force_move(response_topic) }); + } + + bool handle_empty_client_id(epsp_type& epsp, std::string const& client_id, bool clean_start, properties& connack_props) { + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + if (client_id.empty()) { + if (clean_start) { + epsp.set_client_id(create_uuid_string()); + } else { + // https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349242 + // If the Client supplies a zero-byte ClientId, + // the Client MUST also set CleanSession to 1 [MQTT-3.1.3-7]. + // If it's a not a clean session, but no client id is provided, + // we would have no way to map this connection's session to a new connection later. + // So the connection must be rejected. + if (connack_) { + epsp.async_send(v3_1_1::connack_packet{ false, connect_return_code::identifier_rejected }, + [epsp](error_code const& ec) mutable { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + epsp.async_close(as::bind_executor(epsp.get_executor(), [epsp] { + ASYNC_MQTT_LOG("mqtt_broker", info) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "closed"; + })); + }); + } + return false; + } + } + break; + case protocol_version::v5: + if (client_id.empty()) { + // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901059 + // A Server MAY allow a Client to supply a ClientID that has a length of zero bytes, + // however if it does so the Server MUST treat this as a special case and assign a + // unique ClientID to that Client [MQTT-3.1.3-6]. It MUST then process the + // CONNECT packet as if the Client had provided that unique ClientID, + // and MUST return the Assigned Client Identifier in the CONNACK packet [MQTT-3.1.3-7]. + // If the Server rejects the ClientID it MAY respond to the CONNECT packet with a CONNACK + // using Reason Code 0x85 (Client Identifier not valid) as described in section 4.13 + // Handling errors, and then it MUST close the Network Connection [MQTT-3.1.3-8]. + // + // mqtt_cpp author's note: On v5.0, no Clean Start restriction is described. + epsp.set_client_id(create_uuid_string()); + connack_props.emplace_back(property::assigned_client_identifier{ std::string{ epsp.get_client_id() } }); + } + break; + default: + BOOST_ASSERT(false); + return false; + } + return true; + } + + struct send_connack_op { + this_type& brk; + epsp_type epsp; + bool session_present; + bool authenticated; + properties props; + + template + void operator()(Self& self) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "send_connack"; + // Reply to the connect message. + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + if (brk.connack_) { + epsp.async_send( + v3_1_1::connack_packet{ + session_present, + authenticated ? connect_return_code::accepted : connect_return_code::not_authorized, + }, + force_move(self)); + } + break; + case protocol_version::v5: + // connack_props_ member varible is for testing + if (brk.connack_props_.empty()) { + // props local variable is is for real case + props.emplace_back(property::topic_alias_maximum{ topic_alias_max }); + props.emplace_back(property::receive_maximum{ receive_maximum_max }); + if (brk.connack_) { + epsp.async_send( + v5::connack_packet{ session_present, + authenticated ? connect_reason_code::success : connect_reason_code::not_authorized, + force_move(props) }, + force_move(self)); + } + } else { + // use connack_props_ for testing + if (brk.connack_) { + epsp.async_send( + v5::connack_packet{ session_present, + authenticated ? connect_reason_code::success : connect_reason_code::not_authorized, + brk.connack_props_ }, + force_move(self)); + } + } + break; + default: + BOOST_ASSERT(false); + break; + } + } + + template + void operator()(Self& self, error_code se) { + self.complete(se); + } + }; + + template + auto send_connack(epsp_type& epsp, bool session_present, bool authenticated, properties props, CompletionToken&& token) { + auto exe = epsp.get_executor(); + return as::async_compose( + send_connack_op{ *this, force_move(epsp), session_present, authenticated, force_move(props) }, token, exe); + } + + void publish_handler(epsp_type epsp, + packet_id_type packet_id, + pub::opts opts, + std::string topic, + std::vector payload, + properties props) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + auto send_pubres = [&](bool authorized, bool matched) { + switch (opts.get_qos()) { + case qos::at_least_once: + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::puback_packet{ packet_id }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: { + auto packet = [&] { + if (authorized) { + if (puback_props_.empty()) { + if (matched) { + return v5::puback_packet{ packet_id }; + } else { + return v5::puback_packet{ packet_id, puback_reason_code::no_matching_subscribers }; + } + } else { + if (matched) { + return v5::puback_packet{ packet_id, puback_reason_code::success, puback_props_ }; + } else { + return v5::puback_packet{ packet_id, puback_reason_code::no_matching_subscribers, puback_props_ }; + } + }; + } else { + if (puback_props_.empty()) { + return v5::puback_packet{ packet_id, puback_reason_code::not_authorized }; + } else { + return v5::puback_packet{ packet_id, puback_reason_code::not_authorized, puback_props_ }; + } + } + }(); + epsp.async_send(force_move(packet), [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + break; + case qos::exactly_once: + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::pubrec_packet{ packet_id }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: { + auto packet = [&] { + if (authorized) { + if (pubrec_props_.empty()) { + if (matched) { + return v5::pubrec_packet{ packet_id }; + } else { + return v5::pubrec_packet{ packet_id, pubrec_reason_code::no_matching_subscribers }; + } + } else { + if (matched) { + return v5::pubrec_packet{ packet_id, pubrec_reason_code::success, pubrec_props_ }; + } else { + return v5::pubrec_packet{ packet_id, pubrec_reason_code::no_matching_subscribers, pubrec_props_ }; + } + }; + } else { + if (pubrec_props_.empty()) { + return v5::pubrec_packet{ packet_id, pubrec_reason_code::not_authorized }; + } else { + return v5::pubrec_packet{ packet_id, pubrec_reason_code::not_authorized, pubrec_props_ }; + } + } + }(); + epsp.async_send(force_move(packet), [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + break; + default: + break; + } + }; + + // See if this session is authorized to publish this topic + if ([&] { + std::shared_lock g_sec{ mtx_security_ }; + return security_.auth_pub(topic, (*it)->get_username()) != security::authorization::type::allow; + }()) { + // Publish not authorized + send_pubres(false, false); + return; + } + + properties forward_props; + + for (auto&& prop : props) { + force_move(prop).visit(overload{ [&](property::topic_alias&&) { + // TopicAlias is not forwarded + // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901113 + // A receiver MUST NOT carry forward any Topic Alias mappings from + // one Network Connection to another [MQTT-3.3.2-7]. + }, + [&](property::subscription_identifier&& p) { + ASYNC_MQTT_LOG("mqtt_broker", warning) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) + << "Subscription Identifier from client not forwarded sid:" << p.val(); + }, + [&](auto&& p) { forward_props.push_back(force_move(p)); } }); + } + + bool matched = do_publish(**it, force_move(topic), force_move(payload), + opts.get_qos() | opts.get_retain(), // remove dup flag + force_move(forward_props)); + + send_pubres(true, matched); + } + + /** + * @brief do_publish Publish a message to any subscribed clients. + * + * @param source_ss - soource session_state. + * @param topic - The topic to publish the message on. + * @param payload - The payload of the message. + * @param pubopts - publish options + * @param props - properties + */ + bool do_publish(session_state const& source_ss, + std::string topic, + std::vector payload, + pub::opts opts, + properties props) { + bool matched = false; + + // Get auth rights for this topic + // auth_users prepared once here, and then referred multiple times in subs_map_.modify() for efficiency + auto auth_users = [&] { + std::shared_lock g_sec{ mtx_security_ }; + return security_.auth_sub(topic); + }(); + + // publish the message to subscribers. + // retain is delivered as the original only if rap_value is rap::retain. + // On MQTT v3.1.1, rap_value is always rap::dont. + auto deliver = [&](session_state& ss, subscription& sub, auto const& auth_users) { + // See if this session is authorized to subscribe this topic + { + std::shared_lock g_sec{ mtx_security_ }; + auto access = security_.auth_sub_user(auth_users, ss.get_username()); + if (access != security::authorization::type::allow) + return false; + } + pub::opts new_opts = std::min(opts.get_qos(), sub.opts.get_qos()); + if (sub.opts.get_rap() == sub::rap::retain && opts.get_retain() == pub::retain::yes) { + new_opts |= pub::retain::yes; + } + + if (sub.sid) { + props.push_back(property::subscription_identifier(boost::numeric_cast(*sub.sid))); + ss.deliver(timer_ioc_, topic, payload, new_opts, props); + props.pop_back(); + } else { + ss.deliver(timer_ioc_, topic, payload, new_opts, props); + } + return true; + }; + + // share_name topic_filter + std::set> sent; + + { + std::shared_lock g{ mtx_subs_map_ }; + subs_map_.modify(topic, [&](std::string const& /*key*/, subscription& sub) { + if (sub.sharename.empty()) { + // Non shared subscriptions + + // If NL (no local) subscription option is set and + // publisher is the same as subscriber, then skip it. + if (sub.opts.get_nl() == sub::nl::yes && sub.ss.get().client_id() == source_ss.client_id()) + return; + if (deliver(sub.ss.get(), sub, auth_users)) + matched = true; + } else { + // Shared subscriptions + bool inserted; + std::tie(std::ignore, inserted) = sent.emplace(sub.sharename, sub.topic); + if (inserted) { + if (auto ssr_sub_opt = shared_targets_.get_target(sub.sharename, sub.topic)) { + auto [ssr, sub] = *ssr_sub_opt; + if (deliver(ssr.get(), sub, auth_users)) + matched = true; + } + } + } + }); + } + + std::optional message_expiry_interval; + if (source_ss.get_protocol_version() == protocol_version::v5) { + for (auto const& prop : props) { + prop.visit(overload{ [&](property::message_expiry_interval const& v) { + message_expiry_interval.emplace(std::chrono::seconds(v.val())); + }, + [&](auto const&) {} }); + } + } + + /* + * If the message is marked as being retained, then we + * keep it in case a new subscription is added that matches + * this topic. + * + * @note: The MQTT standard 3.3.1.3 RETAIN makes it clear that + * retained messages are global based on the topic, and + * are not scoped by the client id. So any client may + * publish a retained message on any topic, and the most + * recently published retained message on a particular + * topic is the message that is stored on the server. + * + * @note: The standard doesn't make it clear that publishing + * a message with zero length, but the retain flag not + * set, does not result in any existing retained message + * being removed. However, internet searching indicates + * that most brokers have opted to keep retained messages + * when receiving payload of zero bytes, unless the so + * received message has the retain flag set, in which case + * the retained message is removed. + */ + if (opts.get_retain() == pub::retain::yes) { + if (payload.empty()) { + std::lock_guard g(mtx_retains_); + retains_.erase(topic); + } else { + std::shared_ptr tim_message_expiry; + if (message_expiry_interval) { + tim_message_expiry = std::make_shared(timer_ioc_, *message_expiry_interval); + tim_message_expiry->async_wait([this, topic = topic, wp = std::weak_ptr(tim_message_expiry)]( + boost::system::error_code const& ec) { + if (auto sp = wp.lock()) { + if (!ec) { + retains_.erase(topic); + } + } + }); + } + + std::lock_guard g(mtx_retains_); + retains_.insert_or_assign( + topic, retain_type{ topic, force_move(payload), force_move(props), opts.get_qos(), tim_message_expiry }); + } + } + return matched; + } + + void puback_handler(epsp_type epsp, packet_id_type packet_id, puback_reason_code /*reason_code*/, properties /*props*/ + ) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& ss = const_cast&>(**it); + ss.erase_inflight_message_by_packet_id(packet_id); + ss.send_offline_messages_by_packet_id_release(); + } + + void pubrec_handler(epsp_type epsp, packet_id_type packet_id, pubrec_reason_code reason_code, properties /*props*/ + ) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& ss = const_cast&>(**it); + + if (make_error_code(reason_code)) + return; + auto rc = [&] { + ss.erase_inflight_message_by_packet_id(packet_id); + if (!epsp.is_publish_processing(packet_id)) { + return pubrel_reason_code::packet_identifier_not_found; + } else { + return pubrel_reason_code::success; + } + }(); + + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::pubrel_packet{ packet_id }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: { + auto packet = [&] { + if (rc == pubrel_reason_code::success) { + if (pubrel_props_.empty()) { + return v5::pubrel_packet{ packet_id }; + } else { + return v5::pubrel_packet{ packet_id, rc, pubrel_props_ }; + } + } else { + if (pubrel_props_.empty()) { + return v5::pubrel_packet{ packet_id, rc }; + } else { + return v5::pubrel_packet{ packet_id, rc, pubrel_props_ }; + } + } + }(); + epsp.async_send(force_move(packet), [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + } + + void pubrel_handler(epsp_type epsp, packet_id_type packet_id, pubrel_reason_code reason_code, properties /*props*/ + ) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::pubcomp_packet{ packet_id }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: { + auto packet = [&] { + if (reason_code == pubrel_reason_code::success) { + if (pubcomp_props_.empty()) { + return v5::pubcomp_packet{ packet_id }; + } else { + return v5::pubcomp_packet{ packet_id, pubcomp_reason_code::success, pubcomp_props_ }; + } + } else { + if (pubcomp_props_.empty()) { + return v5::pubcomp_packet{ packet_id, static_cast(reason_code) }; + } else { + return v5::pubcomp_packet{ packet_id, static_cast(reason_code), pubcomp_props_ }; + } + } + }(); + epsp.async_send(force_move(packet), [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + } + + void pubcomp_handler(epsp_type epsp, packet_id_type packet_id, pubcomp_reason_code /*reason_code*/, properties /*props*/ + ) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& ss = const_cast&>(**it); + ss.erase_inflight_message_by_packet_id(packet_id); + ss.send_offline_messages_by_packet_id_release(); + } + + void subscribe_handler(epsp_type epsp, + packet_id_type packet_id, + std::vector const& entries, + properties props) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + // The element of sessions_ must have longer lifetime + // than corresponding subscription. + // Because the subscription store the reference of the element. + std::optional> ssr_opt; + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& ss = const_cast&>(**it); + ssr_opt.emplace(ss); + + BOOST_ASSERT(ssr_opt); + session_state_ref ssr{ *ssr_opt }; + + auto publish_proc = [this, &ssr, &epsp](retain_type const& r, qos qos_value, std::optional sid) { + auto props = r.props; + if (sid) { + props.push_back(property::subscription_identifier(std::uint32_t(*sid))); + } + if (r.tim_message_expiry) { + auto d = std::chrono::duration_cast(r.tim_message_expiry->expiry() - + std::chrono::steady_clock::now()) + .count(); + for (auto& prop : props) { + prop.visit(overload{ + [&](property::message_expiry_interval& v) { v = property::message_expiry_interval(static_cast(d)); }, + [&](auto&) {} }); + } + } + ssr.get().publish(epsp, timer_ioc_, r.topic, r.payload, std::min(r.qos_value, qos_value) | pub::retain::yes, props); + }; + + std::vector> retain_deliver; + retain_deliver.reserve(entries.size()); + + // subscription identifier + std::optional sid; + + // An in-order list of qos settings, used to send the reply. + // The MQTT protocol 3.1.1 - 3.8.4 Response - paragraph 6 + // allows the server to grant a lower QOS than requested + // So we reply with the QOS setting that was granted + // not the one requested. + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: { + std::vector res; + res.reserve(entries.size()); + for (auto& e : entries) { + if (!e || [&] { + std::shared_lock g_sec{ mtx_security_ }; + return security_.is_subscribe_authorized(ss.get_username(), e.topic()); + }()) { + res.emplace_back(qos_to_suback_return_code(e.opts().get_qos())); // converts to granted_qos_x + ssr.get().subscribe(e.sharename(), e.topic(), e.opts(), [&] { + std::shared_lock g(mtx_retains_); + retains_.find(e.topic(), [&](retain_type const& r) { + retain_deliver.emplace_back( + [&publish_proc, &r, qos_value = e.opts().get_qos(), sid] { publish_proc(r, qos_value, sid); }); + }); + }); + } else { + // User not authorized to subscribe to topic filter + res.emplace_back(suback_return_code::failure); + } + } + // Acknowledge the subscriptions, and the registered QOS settings + epsp.async_send(v3_1_1::suback_packet{ packet_id, force_move(res) }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + case protocol_version::v5: { + // Get subscription identifier + for (auto const& prop : props) { + prop.visit(overload{ [&](property::subscription_identifier const& v) { + // TBD error if 0 + if (v.val() != 0) { + sid.emplace(v.val()); + } + }, + [&](auto const&) {} }); + if (sid) + break; + } + + std::vector res; + res.reserve(entries.size()); + for (auto& e : entries) { + if (e) { + if ([&] { + std::shared_lock g_sec{ mtx_security_ }; + return security_.is_subscribe_authorized(ss.get_username(), e.topic()); + }()) { + res.emplace_back(qos_to_suback_reason_code(e.opts().get_qos())); // converts to granted_qos_x + ssr.get().subscribe( + e.sharename(), e.topic(), e.opts(), + [&] { + std::shared_lock g(mtx_retains_); + retains_.find(e.topic(), [&](retain_type const& r) { + retain_deliver.emplace_back( + [&publish_proc, &r, qos_value = e.opts().get_qos(), sid] { publish_proc(r, qos_value, sid); }); + }); + }, + sid); + } else { + // User not authorized to subscribe to topic filter + res.emplace_back(suback_reason_code::not_authorized); + } + } else { + res.emplace_back(suback_reason_code::topic_filter_invalid); + } + } + if (h_subscribe_props_) + h_subscribe_props_(props); + // Acknowledge the subscriptions, and the registered QOS settings + epsp.async_send(v5::suback_packet{ packet_id, force_move(res), suback_props_ }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + + for (auto const& f : retain_deliver) { + f(); + } + } + + void unsubscribe_handler(epsp_type epsp, + packet_id_type packet_id, + std::vector entries, + properties props) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + std::shared_lock g(mtx_sessions_); + auto& idx = sessions_.template get(); + auto it = idx.find(epsp); + + // broker uses async_* APIs + // If broker erase a connection, then async_force_disconnect() + // and/or async_force_disconnect () is called. + // During async operation, spep is valid but it has already been + // erased from sessions_ + if (it == idx.end()) + return; + + // The element of sessions_ must have longer lifetime + // than corresponding subscription. + // Because the subscription store the reference of the element. + std::optional> ssr_opt; + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& ss = const_cast&>(**it); + ssr_opt.emplace(ss); + + BOOST_ASSERT(ssr_opt); + session_state_ref ssr{ *ssr_opt }; + + // For each subscription that this connection has + // Compare against the list of topic filters, and remove + // the subscription if the topic filter is in the list. + for (auto const& e : entries) { + ssr.get().unsubscribe(e.sharename(), e.topic()); + } + + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::unsuback_packet{ packet_id }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: + if (h_unsubscribe_props_) + h_unsubscribe_props_(props); + epsp.async_send( + v5::unsuback_packet{ packet_id, std::vector(entries.size(), unsuback_reason_code::success), + unsuback_props_ }, + [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + default: + BOOST_ASSERT(false); + break; + } + } + + void pingreq_handler(epsp_type epsp) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + if (!pingresp_) + return; + + switch (epsp.get_protocol_version()) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::pingresp_packet{}, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: + epsp.async_send(v5::pingresp_packet{}, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + default: + BOOST_ASSERT(false); + break; + } + } + + void disconnect_handler(epsp_type epsp, disconnect_reason_code rc, properties /*props*/ + ) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "disconnect_handler"; + if (delay_disconnect_) { + tim_disconnect_.expires_after(*delay_disconnect_); + tim_disconnect_.wait(); + } + close_proc(force_move(epsp), rc == disconnect_reason_code::disconnect_with_will_message, rc); + } + + struct close_proc_no_lock_op { + close_proc_no_lock_op(this_type& brk, epsp_type epsp, bool send_will, std::optional rc_opt) + : brk{ brk }, epsp{ force_move(epsp) }, send_will{ send_will }, rc_opt{ rc_opt } {} + + this_type& brk; + epsp_type epsp; + bool send_will; + std::optional rc_opt; + enum { close, complete } state = close; + + template + void operator()(Self& self) { + auto do_send_will = [&](session_state& ss) { + if (send_will) { + ss.send_will(); + } else { + ss.clear_will(); + } + }; + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "close_proc_no_lock"; + + auto& idx = brk.sessions_.template get(); + auto it = idx.find(epsp); + + // act_sess_it == act_sess_idx.end() could happen if broker accepts + // the session from client but the client closes the session before sending + // MQTT `CONNECT` message. + // In this case, do nothing is correct behavior. + if (it == idx.end()) { + self.complete(false); + return; + } + + if ((*it)->remain_after_close()) { + idx.modify(it, [&](std::shared_ptr>& sssp) { + do_send_will(*sssp); + if (rc_opt) { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "disconnect_and_close() cid:" << sssp->client_id(); + auto a_epsp{ epsp }; + disconnect_and_close(a_epsp, (*it)->get_protocol_version(), *rc_opt, as::append(force_move(self), sssp)); + } else { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "close cid:" << sssp->client_id(); + auto a_epsp{ epsp }; + a_epsp.async_close(as::append(force_move(self), sssp)); + } + }); + } else { + auto sssp{ force_move(idx.extract(it).value()) }; + do_send_will(*sssp); + if (rc_opt) { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "disconnect_and_close() cid:" << sssp->client_id(); + auto a_epsp{ epsp }; + disconnect_and_close(a_epsp, sssp->get_protocol_version(), *rc_opt, + as::consign(as::append(force_move(self), nullptr), sssp)); + } else { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "close cid:" << sssp->client_id(); + auto a_epsp{ epsp }; + a_epsp.async_close(as::consign(as::append(force_move(self), nullptr), sssp)); + } + } + } + + template + void operator()(Self& self, std::shared_ptr> sssp) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) + << "disconnect(optional)_and_closed"; + if (sssp) { + // sessions_ index is never changed because owner_less + // remains original order even if shared_count would be zero + sssp->become_offline(epsp, [&brk = this->brk](std::shared_ptr const& sp_tim) { + // lock for expire (async) + std::lock_guard g(brk.mtx_sessions_); + brk.sessions_.template get().erase(sp_tim); + }); + self.complete(true); + } else { + self.complete(false); + } + } + }; + + /** + * @brief close_proc_no_lock - clean up a connection that has been closed. + * + * @param ep - The underlying server (of whichever type) that is disconnecting. + * @param send_will - Whether to publish this connections last will + * @return true if offline session is remained, otherwise false + */ + // TODO: Maybe change the name of this function. + template + auto close_proc_no_lock(epsp_type epsp, + bool send_will, + std::optional rc_opt, + CompletionToken&& token) { + auto exe = epsp.get_executor(); + return as::async_compose( + close_proc_no_lock_op{ *this, force_move(epsp), send_will, force_move(rc_opt) }, token, exe); + } + + /** + * @brief close_proc - clean up a connection that has been closed. + * + * @param ep - The underlying server (of whichever type) that is disconnecting. + * @param send_will - Whether to publish this connections last will + * @param rc - Reason Code for send pack DISCONNECT + * @return true if offline session is remained, otherwise false + */ + // TODO: Maybe change the name of this function. + void close_proc(epsp_type epsp, bool send_will, std::optional rc_opt = std::nullopt) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "close_proc"; + std::lock_guard g(mtx_sessions_); + close_proc_no_lock(force_move(epsp), send_will, rc_opt, [](bool) {}); + } + + void auth_handler(epsp_type epsp, auth_reason_code /*rc*/, properties props) { + auto usg = unique_scope_guard([&] { async_read_packet(force_move(epsp)); }); + + if (h_auth_props_) + h_auth_props_(force_move(props)); + } + + struct disconnect_and_close_op { + disconnect_and_close_op(epsp_type epsp, protocol_version version, disconnect_reason_code rc) + : epsp{ force_move(epsp) }, version{ version }, rc{ rc }, state{ [&] { + if (version == protocol_version::v3_1_1) { + return close; + } else { + return disconnect; + } + }() } {} + + epsp_type epsp; + protocol_version version; + disconnect_reason_code rc; + enum { disconnect, close, complete } state; + + template + void operator()(Self& self, error_code = {}) { + switch (state) { + case disconnect: { + state = close; + auto a_epsp{ epsp }; + a_epsp.async_send(v5::disconnect_packet{ rc, properties{} }, force_move(self)); + } break; + case close: { + state = complete; + auto a_epsp{ epsp }; + a_epsp.async_close(force_move(self)); + } break; + case complete: + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "closed"; + self.complete(); + break; + } + } + }; + + template + static auto disconnect_and_close(epsp_type epsp, + protocol_version version, + disconnect_reason_code rc, + CompletionToken&& token) { + auto exe = epsp.get_executor(); + return as::async_compose(disconnect_and_close_op{ force_move(epsp), version, rc }, token, exe); + } + +private: + as::io_context& timer_ioc_; ///< The boost asio context to run this broker on. + as::steady_timer tim_disconnect_; ///< Used to delay disconnect handling for testing + std::optional delay_disconnect_; ///< Used to delay disconnect handling for testing + + // Authorization and authentication settings + mutable mutex mtx_security_; + security security_; + + mutable mutex mtx_subs_map_; + sub_con_map subs_map_; ///< subscription information + shared_target shared_targets_; ///< shared subscription targets + + ///< Map of active client id and connections + /// session_state has references of subs_map_ and shared_targets_. + /// because session_state (member of sessions_) has references of subs_map_ and shared_targets_. + mutable mutex mtx_sessions_; + session_states sessions_; + + mutable mutex mtx_retains_; + retained_messages retains_; ///< A list of messages retained so they can be sent to newly subscribed clients. + + // MQTTv5 members + properties connack_props_; + properties suback_props_; + properties unsuback_props_; + properties puback_props_; + properties pubrec_props_; + properties pubrel_props_; + properties pubcomp_props_; + std::function h_connect_props_; + std::function h_disconnect_props_; + std::function h_publish_props_; + std::function h_puback_props_; + std::function h_pubrec_props_; + std::function h_pubrel_props_; + std::function h_pubcomp_props_; + std::function h_subscribe_props_; + std::function h_unsubscribe_props_; + std::function h_auth_props_; + bool pingresp_ = true; + bool connack_ = true; + bool recycling_allocator_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_BROKER_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/constant.hpp b/exes/mqtt-bridge/tests/inc/broker/constant.hpp new file mode 100644 index 0000000000..b1398bef03 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/constant.hpp @@ -0,0 +1,19 @@ +#pragma once +// Copyright Wouter van Kleunen 2022 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_CONSTANT_HPP) +#define ASYNC_MQTT_BROKER_CONSTANT_HPP + +#include + +namespace async_mqtt { + +static constexpr std::size_t max_cn_size = 0xffff; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_CONSTANT_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/endpoint_variant.hpp b/exes/mqtt-bridge/tests/inc/broker/endpoint_variant.hpp new file mode 100644 index 0000000000..87f0d5cba0 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/endpoint_variant.hpp @@ -0,0 +1,218 @@ +#pragma once +// Copyright Takatoshi Kondo 2022 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_ENDPOINT_VARIANT_HPP) +#define ASYNC_MQTT_BROKER_ENDPOINT_VARIANT_HPP + +#include +#include +#include +#include + +namespace async_mqtt { + +template +struct basic_endpoint_variant : std::variant>...> { + using this_type = basic_endpoint_variant; + using base_type = std::variant>...>; + using base_type::base_type; + + static constexpr role role_value = Role; + static constexpr std::size_t packet_id_bytes = PacketIdBytes; + + template + basic_endpoint const& as() const { + return *std::get>>(*this); + } + template + basic_endpoint& as() { + return *std::get>>(*this); + } + + struct weak_type : std::variant>...> { + using base_type = std::variant>...>; + using base_type::base_type; + using shared_type = std::variant>...>; + this_type lock() { + return std::visit([&](auto& wp) -> this_type { return wp.lock(); }, *this); + } + bool operator<(weak_type const& other) const { + return std::visit([&](auto& lhs) { return std::visit([&](auto& rhs) { return lhs.owner_before(rhs); }, other); }, + *this); + } + }; +}; + +template +using endpoint_variant = basic_endpoint_variant; + +template +class epsp_wrap { +public: + using epsp_type = Epsp; + using this_type = epsp_wrap; + static constexpr std::size_t packet_id_bytes = epsp_type::packet_id_bytes; + using packet_variant_type = basic_packet_variant; + using weak_type = typename epsp_type::weak_type; + + epsp_wrap(epsp_type epsp) : epsp_{ force_move(epsp) } {} + + template + decltype(auto) visit(Func&& func) const { + return std::visit([&](auto& ep) -> decltype(auto) { return std::forward(func)(*ep); }, epsp_); + } + + template + decltype(auto) visit(Func&& func) { + return std::visit([&](auto& ep) -> decltype(auto) { return std::forward(func)(*ep); }, epsp_); + } + + template + void dispatch(Func&& func) const { + visit([&](auto& ep) { as::dispatch(as::bind_executor(ep.get_executor(), std::forward(func))); }); + } + + as::any_io_executor get_executor() { + return visit([&](auto& ep) -> as::any_io_executor { return ep.get_executor(); }); + } + + // async functions + + template + auto async_acquire_unique_packet_id(CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_acquire_unique_packet_id(std::forward(token)); }); + } + + template + auto async_acquire_unique_packet_id_wait_until(CompletionToken&& token) { + return visit( + [&](auto& ep) { return ep.async_acquire_unique_packet_id_wait_until(std::forward(token)); }); + } + + template + auto async_register_packet_id(typename basic_packet_id_type::type packet_id, CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_register_packet_id(packet_id, std::forward(token)); }); + } + + template + auto async_release_packet_id(typename basic_packet_id_type::type packet_id, CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_release_packet_id(packet_id, std::forward(token)); }); + } + + template + auto async_send(Packet&& packet, CompletionToken&& token) { + return visit( + [&](auto& ep) { return ep.async_send(std::forward(packet), std::forward(token)); }); + } + + template + auto async_recv(CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_recv(std::forward(token)); }); + } + + template + auto async_close(CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_close(std::forward(token)); }); + } + + template + auto async_restore_packets(std::vector> pvs, CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_restore_packets(force_move(pvs), std::forward(token)); }); + } + + template + auto async_get_stored_packets(CompletionToken&& token) { + return visit([&](auto& ep) { return ep.async_get_stored_packets(std::forward(token)); }); + } + + // sync APIs (Thread unsafe without strand) + + std::optional::type> acquire_unique_packet_id() { + return visit([&](auto& ep) { return ep.acquire_unique_packet_id(); }); + } + + bool register_packet_id(typename basic_packet_id_type::type pid) { + return visit([&](auto& ep) { return ep.register_packet_id(pid); }); + } + + void release_packet_id(typename basic_packet_id_type::type pid) { + visit([&](auto& ep) { return ep.release_packet_id(pid); }); + } + + /** + * @brief Get processed but not released QoS2 packet ids + * This function should be called after disconnection + * @return set of packet_ids + */ + std::set::type> get_qos2_publish_handled_pids() const { + return visit([&](auto& ep) { return ep.get_qos2_publish_handled_pids(); }); + } + + /** + * @brief Restore processed but not released QoS2 packet ids + * This function should be called before receive the first publish + * @param pids packet ids + */ + void restore_qos2_publish_handled_pids(std::set::type> pids) { + visit([&](auto& ep) { return ep.restore_qos2_publish_handled_pids(pids); }); + } + + void restore_packets(std::vector> pvs) { + visit([&](auto& ep) { ep.restore(pvs); }); + } + + std::vector> get_stored_packets() const { + return visit([&](auto& ep) { return ep.get_stored_packets(); }); + } + + void set_preauthed_user_name(std::optional user_name) { preauthed_user_name_ = force_move(user_name); } + + std::optional const& get_preauthed_user_name() const { return preauthed_user_name_; } + + protocol_version get_protocol_version() const { + if (!protocol_version_) { + // On multi threaded environment, + // The following code requires running in ep's strand. + // It is safe because it is always called from ep's strand + // in connect_handler for the first time. + protocol_version_.emplace(visit([&](auto& ep) { return ep.get_protocol_version(); })); + } + return *protocol_version_; + } + + bool is_publish_processing(typename basic_packet_id_type::type pid) const { + return visit([&](auto& ep) { return ep.is_publish_processing(pid); }); + } + + void set_client_id(std::string cid) { client_id_ = force_move(cid); } + + std::string const& get_client_id() const { return client_id_; } + + operator bool() const { + return std::visit([&](auto const& epsp) { return static_cast(epsp); }, epsp_); + } + + operator weak_type() const { + return std::visit([&](auto& ep) -> weak_type { return ep; }, epsp_); + } + + void const* get_address() const { + return std::visit([&](auto const& epsp) -> void const* { return epsp.get(); }, epsp_); + } + + epsp_type get_epvsp() { return epsp_; } + +private: + epsp_type epsp_; + std::string client_id_; + std::optional preauthed_user_name_; + mutable std::optional protocol_version_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_ENDPOINT_VARIANT_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/external/picosha2.h b/exes/mqtt-bridge/tests/inc/broker/external/picosha2.h new file mode 100644 index 0000000000..373a6cf705 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/external/picosha2.h @@ -0,0 +1,381 @@ +#pragma once +/* +The MIT License (MIT) + +Copyright (C) 2017 okdshin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef PICOSHA2_H +#define PICOSHA2_H +// picosha2:20140213 + +#ifndef PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR +#define PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR 1048576 //=1024*1024: default is 1MB memory +#endif + +#include +#include +#include +#include +#include +#include +namespace picosha2 { +typedef unsigned long word_t; +typedef unsigned char byte_t; + +static const size_t k_digest_size = 32; + +namespace detail { +inline byte_t mask_8bit(byte_t x) { + return x & 0xff; +} + +inline word_t mask_32bit(word_t x) { + return x & 0xffffffff; +} + +const word_t add_constant[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, + 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, + 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, + 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +}; + +const word_t initial_message_digest[8] = { 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 }; + +inline word_t ch(word_t x, word_t y, word_t z) { + return (x & y) ^ ((~x) & z); +} + +inline word_t maj(word_t x, word_t y, word_t z) { + return (x & y) ^ (x & z) ^ (y & z); +} + +inline word_t rotr(word_t x, std::size_t n) { + assert(n < 32); + return mask_32bit((x >> n) | (x << (32 - n))); +} + +inline word_t bsig0(word_t x) { + return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); +} + +inline word_t bsig1(word_t x) { + return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); +} + +inline word_t shr(word_t x, std::size_t n) { + assert(n < 32); + return x >> n; +} + +inline word_t ssig0(word_t x) { + return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); +} + +inline word_t ssig1(word_t x) { + return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); +} + +template +void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) { + assert(first + 64 == last); + static_cast(last); // for avoiding unused-variable warning + word_t w[64]; + std::fill(w, w + 64, word_t(0)); + for (std::size_t i = 0; i < 16; ++i) { + w[i] = (static_cast(mask_8bit(*(first + i * 4))) << 24) | + (static_cast(mask_8bit(*(first + i * 4 + 1))) << 16) | + (static_cast(mask_8bit(*(first + i * 4 + 2))) << 8) | + (static_cast(mask_8bit(*(first + i * 4 + 3)))); + } + for (std::size_t i = 16; i < 64; ++i) { + w[i] = mask_32bit(ssig1(w[i - 2]) + w[i - 7] + ssig0(w[i - 15]) + w[i - 16]); + } + + word_t a = *message_digest; + word_t b = *(message_digest + 1); + word_t c = *(message_digest + 2); + word_t d = *(message_digest + 3); + word_t e = *(message_digest + 4); + word_t f = *(message_digest + 5); + word_t g = *(message_digest + 6); + word_t h = *(message_digest + 7); + + for (std::size_t i = 0; i < 64; ++i) { + word_t temp1 = h + bsig1(e) + ch(e, f, g) + add_constant[i] + w[i]; + word_t temp2 = bsig0(a) + maj(a, b, c); + h = g; + g = f; + f = e; + e = mask_32bit(d + temp1); + d = c; + c = b; + b = a; + a = mask_32bit(temp1 + temp2); + } + *message_digest += a; + *(message_digest + 1) += b; + *(message_digest + 2) += c; + *(message_digest + 3) += d; + *(message_digest + 4) += e; + *(message_digest + 5) += f; + *(message_digest + 6) += g; + *(message_digest + 7) += h; + for (std::size_t i = 0; i < 8; ++i) { + *(message_digest + i) = mask_32bit(*(message_digest + i)); + } +} + +} // namespace detail + +template +void output_hex(InIter first, InIter last, std::ostream& os) { + os.setf(std::ios::hex, std::ios::basefield); + while (first != last) { + os.width(2); + os.fill('0'); + os << static_cast(*first); + ++first; + } + os.setf(std::ios::dec, std::ios::basefield); +} + +template +void bytes_to_hex_string(InIter first, InIter last, std::string& hex_str) { + std::ostringstream oss; + output_hex(first, last, oss); + hex_str.assign(oss.str()); +} + +template +void bytes_to_hex_string(const InContainer& bytes, std::string& hex_str) { + bytes_to_hex_string(bytes.begin(), bytes.end(), hex_str); +} + +template +std::string bytes_to_hex_string(InIter first, InIter last) { + std::string hex_str; + bytes_to_hex_string(first, last, hex_str); + return hex_str; +} + +template +std::string bytes_to_hex_string(const InContainer& bytes) { + std::string hex_str; + bytes_to_hex_string(bytes, hex_str); + return hex_str; +} + +class hash256_one_by_one { +public: + hash256_one_by_one() { init(); } + + void init() { + buffer_.clear(); + std::fill(data_length_digits_, data_length_digits_ + 4, word_t(0)); + std::copy(detail::initial_message_digest, detail::initial_message_digest + 8, h_); + } + + template + void process(RaIter first, RaIter last) { + add_to_data_length(static_cast(std::distance(first, last))); + std::copy(first, last, std::back_inserter(buffer_)); + std::size_t i = 0; + for (; i + 64 <= buffer_.size(); i += 64) { + detail::hash256_block(h_, buffer_.begin() + i, buffer_.begin() + i + 64); + } + buffer_.erase(buffer_.begin(), buffer_.begin() + i); + } + + void finish() { + byte_t temp[64]; + std::fill(temp, temp + 64, byte_t(0)); + std::size_t remains = buffer_.size(); + std::copy(buffer_.begin(), buffer_.end(), temp); + temp[remains] = 0x80; + + if (remains > 55) { + std::fill(temp + remains + 1, temp + 64, byte_t(0)); + detail::hash256_block(h_, temp, temp + 64); + std::fill(temp, temp + 64 - 4, byte_t(0)); + } else { + std::fill(temp + remains + 1, temp + 64 - 4, byte_t(0)); + } + + write_data_bit_length(&(temp[56])); + detail::hash256_block(h_, temp, temp + 64); + } + + template + void get_hash_bytes(OutIter first, OutIter last) const { + for (const word_t* iter = h_; iter != h_ + 8; ++iter) { + for (std::size_t i = 0; i < 4 && first != last; ++i) { + *(first++) = detail::mask_8bit(static_cast((*iter >> (24 - 8 * i)))); + } + } + } + +private: + void add_to_data_length(word_t n) { + word_t carry = 0; + data_length_digits_[0] += n; + for (std::size_t i = 0; i < 4; ++i) { + data_length_digits_[i] += carry; + if (data_length_digits_[i] >= 65536u) { + carry = data_length_digits_[i] >> 16; + data_length_digits_[i] &= 65535u; + } else { + break; + } + } + } + void write_data_bit_length(byte_t* begin) { + word_t data_bit_length_digits[4]; + std::copy(data_length_digits_, data_length_digits_ + 4, data_bit_length_digits); + + // convert byte length to bit length (multiply 8 or shift 3 times left) + word_t carry = 0; + for (std::size_t i = 0; i < 4; ++i) { + word_t before_val = data_bit_length_digits[i]; + data_bit_length_digits[i] <<= 3; + data_bit_length_digits[i] |= carry; + data_bit_length_digits[i] &= 65535u; + carry = (before_val >> (16 - 3)) & 65535u; + } + + // write data_bit_length + for (int i = 3; i >= 0; --i) { + (*begin++) = static_cast(data_bit_length_digits[i] >> 8); + (*begin++) = static_cast(data_bit_length_digits[i]); + } + } + std::vector buffer_; + word_t data_length_digits_[4]; // as 64bit integer (16bit x 4 integer) + word_t h_[8]; +}; + +inline void get_hash_hex_string(const hash256_one_by_one& hasher, std::string& hex_str) { + byte_t hash[k_digest_size]; + hasher.get_hash_bytes(hash, hash + k_digest_size); + return bytes_to_hex_string(hash, hash + k_digest_size, hex_str); +} + +inline std::string get_hash_hex_string(const hash256_one_by_one& hasher) { + std::string hex_str; + get_hash_hex_string(hasher, hex_str); + return hex_str; +} + +namespace impl { +template +void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int, std::random_access_iterator_tag) { + hash256_one_by_one hasher; + // hasher.init(); + hasher.process(first, last); + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} + +template +void hash256_impl(InputIter first, InputIter last, OutIter first2, OutIter last2, int buffer_size, std::input_iterator_tag) { + std::vector buffer(buffer_size); + hash256_one_by_one hasher; + // hasher.init(); + while (first != last) { + int size = buffer_size; + for (int i = 0; i != buffer_size; ++i, ++first) { + if (first == last) { + size = i; + break; + } + buffer[i] = *first; + } + hasher.process(buffer.begin(), buffer.begin() + size); + } + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} +} // namespace impl + +template +void hash256(InIter first, + InIter last, + OutIter first2, + OutIter last2, + int buffer_size = PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR) { + picosha2::impl::hash256_impl(first, last, first2, last2, buffer_size, + typename std::iterator_traits::iterator_category()); +} + +template +void hash256(InIter first, InIter last, OutContainer& dst) { + hash256(first, last, dst.begin(), dst.end()); +} + +template +void hash256(const InContainer& src, OutIter first, OutIter last) { + hash256(src.begin(), src.end(), first, last); +} + +template +void hash256(const InContainer& src, OutContainer& dst) { + hash256(src.begin(), src.end(), dst.begin(), dst.end()); +} + +template +void hash256_hex_string(InIter first, InIter last, std::string& hex_str) { + byte_t hashed[k_digest_size]; + hash256(first, last, hashed, hashed + k_digest_size); + std::ostringstream oss; + output_hex(hashed, hashed + k_digest_size, oss); + hex_str.assign(oss.str()); +} + +template +std::string hash256_hex_string(InIter first, InIter last) { + std::string hex_str; + hash256_hex_string(first, last, hex_str); + return hex_str; +} + +inline void hash256_hex_string(const std::string& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +void hash256_hex_string(const InContainer& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +std::string hash256_hex_string(const InContainer& src) { + return hash256_hex_string(src.begin(), src.end()); +} +template +void hash256(std::ifstream& f, OutIter first, OutIter last) { + hash256(std::istreambuf_iterator(f), std::istreambuf_iterator(), first, last); +} +} // namespace picosha2 +#endif // PICOSHA2_H diff --git a/exes/mqtt-bridge/tests/inc/broker/fixed_core_map.hpp b/exes/mqtt-bridge/tests/inc/broker/fixed_core_map.hpp new file mode 100644 index 0000000000..fe0ac448b6 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/fixed_core_map.hpp @@ -0,0 +1,42 @@ +#pragma once +// Copyright Takatoshi Kondo 2021 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_FIXED_CORE_MAP_HPP) +#define ASYNC_MQTT_BROKER_FIXED_CORE_MAP_HPP + +#include + +#if defined(_GNU_SOURCE) + +#include +#include + +namespace async_mqtt { + +inline void map_core_to_this_thread(std::size_t core) { + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(static_cast(core), &mask); + int ret = sched_setaffinity(0, sizeof(mask), &mask); + BOOST_ASSERT(ret == 0); +} + +} // namespace async_mqtt + +#else // defined(_GNU_SOURCE) + +namespace async_mqtt { + +inline void map_core_to_this_thread(std::size_t /*core*/) { + ASYNC_MQTT_LOG("mqtt_broker", warning) << "map_core_to_this_thread() is called but do nothing"; +} + +} // namespace async_mqtt + +#endif // defined(_GNU_SOURCE) + +#endif // ASYNC_MQTT_BROKER_FIXED_CORE_MAP_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/inflight_message.hpp b/exes/mqtt-bridge/tests/inc/broker/inflight_message.hpp new file mode 100644 index 0000000000..66ecc0ceb5 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/inflight_message.hpp @@ -0,0 +1,113 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_INFLIGHT_MESSAGE_HPP) +#define ASYNC_MQTT_BROKER_INFLIGHT_MESSAGE_HPP + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace async_mqtt { + +namespace mi = boost::multi_index; + +class inflight_messages; + +class inflight_message { +public: + inflight_message(store_packet_variant packet, std::shared_ptr tim_message_expiry) + : packet_{ force_move(packet) }, tim_message_expiry_{ force_move(tim_message_expiry) } {} + + packet_id_type packet_id() const { return packet_.packet_id(); } + + template + void send(Epsp& epsp) const { + std::optional packet_opt; + if (tim_message_expiry_) { + packet_.visit(overload{ [&](v5::basic_publish_packet const& m) { + auto updated_packet = m; + auto d = std::chrono::duration_cast(tim_message_expiry_->expiry() - + std::chrono::steady_clock::now()) + .count(); + if (d < 0) + d = 0; + updated_packet.update_message_expiry_interval(static_cast(d)); + packet_opt.emplace(force_move(updated_packet)); + }, + [](auto const&) {} }); + } + epsp.register_packet_id(packet_id()); + epsp.async_send(packet_opt ? *packet_opt : packet_, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } + + store_packet_variant const& packet() const { return packet_; } + +private: + friend class inflight_messages; + + store_packet_variant packet_; + std::shared_ptr tim_message_expiry_; +}; + +class inflight_messages { +public: + void insert(store_packet_variant packet, std::shared_ptr tim_message_expiry) { + messages_.emplace_back(force_move(packet), force_move(tim_message_expiry)); + } + + template + void send_all_messages(Epsp& epsp) { + for (auto const& ifm : messages_) { + ifm.send(epsp); + } + } + + void clear() { messages_.clear(); } + + template + decltype(auto) get() { + return messages_.get(); + } + + template + decltype(auto) get() const { + return messages_.get(); + } + +private: + using mi_inflight_message = mi::multi_index_container< + inflight_message, + mi::indexed_by< + mi::sequenced >, + mi::ordered_unique, BOOST_MULTI_INDEX_CONST_MEM_FUN(inflight_message, packet_id_type, packet_id)>, + mi::ordered_non_unique< + mi::tag, + BOOST_MULTI_INDEX_MEMBER(inflight_message, std::shared_ptr, tim_message_expiry_)> > >; + + mi_inflight_message messages_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_INFLIGHT_MESSAGE_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/mutex.hpp b/exes/mqtt-bridge/tests/inc/broker/mutex.hpp new file mode 100644 index 0000000000..70642fed12 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/mutex.hpp @@ -0,0 +1,19 @@ +#pragma once +// Copyright Takatoshi Kondo 2021 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_MUTEX_HPP) +#define ASYNC_MQTT_BROKER_MUTEX_HPP + +#include + +namespace async_mqtt { + +using mutex = std::shared_timed_mutex; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_MUTEX_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/offline_message.hpp b/exes/mqtt-bridge/tests/inc/broker/offline_message.hpp new file mode 100644 index 0000000000..1cbc64cc25 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/offline_message.hpp @@ -0,0 +1,176 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_OFFLINE_MESSAGE_HPP) +#define ASYNC_MQTT_BROKER_OFFLINE_MESSAGE_HPP + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace async_mqtt { + +namespace mi = boost::multi_index; + +class offline_messages; + +// The offline_message structure holds messages that have been published on a +// topic that a not-currently-connected client is subscribed to. +// When a new connection is made with the client id for this saved data, +// these messages will be published to that client, and only that client. +class offline_message { +public: + offline_message(std::string topic, + std::vector payload, + pub::opts pubopts, + properties props, + std::shared_ptr tim_message_expiry) + : topic_{ force_move(topic) }, payload_(force_move(payload)), pubopts_{ pubopts }, props_(force_move(props)), + tim_message_expiry_{ force_move(tim_message_expiry) } {} + + template + bool send(Epsp epsp, protocol_version ver) { + auto publish = [&](packet_id_type pid) { + switch (ver) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::publish_packet{ pid, topic_, payload_, pubopts_ }, [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", warning) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + break; + case protocol_version::v5: { + auto packet = v5::publish_packet{ pid, topic_, payload_, pubopts_, props_ }; + if (tim_message_expiry_) { + auto d = std::chrono::duration_cast(tim_message_expiry_->expiry() - + std::chrono::steady_clock::now()) + .count(); + if (d < 0) + d = 0; + packet.update_message_expiry_interval(static_cast(d)); + } + epsp.async_send(force_move(packet), [epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", warning) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << ec.message(); + } + }); + } break; + default: + BOOST_ASSERT(false); + break; + } + }; + + auto qos_value = pubopts_.get_qos(); + if (qos_value == qos::at_least_once || qos_value == qos::exactly_once) { + if (auto pid_opt = epsp.acquire_unique_packet_id()) { + publish(*pid_opt); + return true; + } else { + return false; + } + } else { + publish(0); + return true; + } + } + +private: + friend class offline_messages; + + std::string topic_; + std::vector payload_; + pub::opts pubopts_; + properties props_; + std::shared_ptr tim_message_expiry_; +}; + +class offline_messages { +public: + template + void send_until_fail(Epsp& epsp, protocol_version ver) { + epsp.dispatch([this, epsp, ver] { + auto& idx = messages_.get(); + while (!idx.empty()) { + auto it = idx.begin(); + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& m = const_cast(*it); + if (m.send(epsp, ver)) { + idx.pop_front(); + } else { + break; + } + } + }); + } + + void clear() { messages_.clear(); } + + bool empty() const { return messages_.empty(); } + + void push_back(as::io_context& timer_ioc, + std::string pub_topic, + std::vector payload, + pub::opts pubopts, + properties props) { + std::optional message_expiry_interval; + + for (auto const& prop : props) { + prop.visit(overload{ [&](property::message_expiry_interval const& p) { + message_expiry_interval.emplace(std::chrono::seconds(p.val())); + }, + [](auto const&) {} }); + } + + std::shared_ptr tim_message_expiry; + if (message_expiry_interval) { + tim_message_expiry = std::make_shared(timer_ioc, *message_expiry_interval); + tim_message_expiry->async_wait( + [this, wp = std::weak_ptr(tim_message_expiry)](error_code ec) mutable { + if (auto sp = wp.lock()) { + if (!ec) { + messages_.get().erase(sp); + } + } + }); + } + + auto& seq_idx = messages_.get(); + seq_idx.emplace_back(force_move(pub_topic), force_move(payload), pubopts, force_move(props), + force_move(tim_message_expiry)); + } + +private: + using mi_offline_message = mi::multi_index_container< + offline_message, + mi::indexed_by >, + mi::ordered_non_unique, mi::key<&offline_message::tim_message_expiry_> > > >; + + mi_offline_message messages_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_OFFLINE_MESSAGE_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/readme.md b/exes/mqtt-bridge/tests/inc/broker/readme.md new file mode 100644 index 0000000000..650a794545 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/readme.md @@ -0,0 +1,2 @@ +Copied from b859411b2d6ec153fa2736641ee9f238c02ede77 +See discussion: https://github.com/redboltz/async_mqtt/issues/314 diff --git a/exes/mqtt-bridge/tests/inc/broker/retain_type.hpp b/exes/mqtt-bridge/tests/inc/broker/retain_type.hpp new file mode 100644 index 0000000000..b350889f19 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/retain_type.hpp @@ -0,0 +1,45 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_RETAIN_TYPE_HPP) +#define ASYNC_MQTT_BROKER_RETAIN_TYPE_HPP + +#include + +#include +#include +#include + +namespace async_mqtt { + +// A collection of messages that have been retained in +// case clients add a new subscription to the associated topics. +struct retain_type { + retain_type(std::string topic, + std::vector payload, + properties props, + qos qos_value, + std::shared_ptr tim_message_expiry = std::shared_ptr()) + : topic(force_move(topic)), props(force_move(props)), qos_value(qos_value), + tim_message_expiry(force_move(tim_message_expiry)) { + auto it = std::cbegin(payload); + auto end = std::cend(payload); + for (; it != end; ++it) { + this->payload.emplace_back(*it); + } + } + + std::string topic; + std::vector payload; + properties props; + qos qos_value; + std::shared_ptr tim_message_expiry; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_RETAIN_TYPE_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/retained_messages.hpp b/exes/mqtt-bridge/tests/inc/broker/retained_messages.hpp new file mode 100644 index 0000000000..c36a8e0089 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/retained_messages.hpp @@ -0,0 +1,22 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_RETAINED_MESSAGES_HPP) +#define ASYNC_MQTT_BROKER_RETAINED_MESSAGES_HPP + +#include // reference_wrapper + +#include +#include + +namespace async_mqtt { + +using retained_messages = retained_topic_map; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_RETAINED_MESSAGES_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/retained_topic_map.hpp b/exes/mqtt-bridge/tests/inc/broker/retained_topic_map.hpp new file mode 100644 index 0000000000..fd19bbee48 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/retained_topic_map.hpp @@ -0,0 +1,344 @@ +#pragma once +// Copyright Wouter van Kleunen 2019 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_RETAINED_TOPIC_MAP_HPP) +#define ASYNC_MQTT_BROKER_RETAINED_TOPIC_MAP_HPP + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +namespace async_mqtt { + +namespace mi = boost::multi_index; + +template +class retained_topic_map { + // Exceptions used + static void throw_max_stored_topics() { throw std::overflow_error("Retained map maximum number of topics reached"); } + static void throw_no_wildcards_allowed() { + throw std::runtime_error("Retained map no wildcards allowed in retained topic name"); + } + + using node_id_type = std::size_t; + + static constexpr node_id_type root_parent_id = 0; + static constexpr node_id_type root_node_id = 1; + static constexpr node_id_type max_node_id = std::numeric_limits::max(); + + struct path_entry { + // Increase the count for this node + void increase_count() { + if (count == max_count) { + throw_max_stored_topics(); + } + + ++count; + } + + // Decrease the count for this node + void decrease_count() { + BOOST_ASSERT(count >= count); + --count; + } + + std::optional value; + + path_entry(node_id_type parent_id, std::string_view name, node_id_type id) + : parent_id{ parent_id }, name{ name }, id{ id } {} + + std::string_view name_as_string_view() const { return name; } + + node_id_type parent_id; + std::string name; + + node_id_type id; + + std::size_t count = 1; + static constexpr std::size_t max_count = std::numeric_limits::max(); + }; + + struct wildcard_index_tag {}; + struct direct_index_tag {}; + + // allow for two indices on retained topics + using path_entry_set = mi::multi_index_container< + path_entry, + mi::indexed_by< + // index required for direct child access + mi::hashed_unique, mi::key<&path_entry::parent_id, &path_entry::name_as_string_view> >, + // index required for wildcard processing + mi::ordered_non_unique, mi::key<&path_entry::parent_id> > > >; + + using direct_const_iterator = typename path_entry_set::template index::type::const_iterator; + using wildcard_const_iterator = typename path_entry_set::template index::type::const_iterator; + + path_entry_set map; + size_t map_size; + node_id_type next_node_id; + + direct_const_iterator root; + + direct_const_iterator create_topic(std::string_view topic) { + direct_const_iterator parent = root; + + topic_filter_tokenizer(topic, [this, &parent](std::string_view t) { + if (t == "+" || t == "#") { + throw_no_wildcards_allowed(); + } + + node_id_type parent_id = parent->id; + + auto& direct_index = map.template get(); + direct_const_iterator entry = direct_index.find(std::make_tuple(parent_id, t)); + + if (entry == direct_index.end()) { + entry = map.insert(path_entry(parent->id, t, next_node_id++)).first; + if (next_node_id == max_node_id) { + throw_max_stored_topics(); + } + } else { + direct_index.modify(entry, [](path_entry& entry) { entry.increase_count(); }); + } + + parent = entry; + return true; + }); + + return parent; + } + + std::vector find_topic(std::string_view topic) { + std::vector path; + direct_const_iterator parent = root; + + topic_filter_tokenizer(topic, [this, &parent, &path](std::string_view t) { + auto const& direct_index = map.template get(); + auto entry = direct_index.find(std::make_tuple(parent->id, t)); + + if (entry == direct_index.end()) { + path = std::vector(); + return false; + } + + path.push_back(entry); + parent = entry; + return true; + }); + + return path; + } + + // Match all underlying topics when a hash entry is matched + // perform a breadth-first iteration over all items in the tree below + template + void match_hash_entries(node_id_type parent, Output&& callback, bool ignore_system) const { + std::deque entries; + entries.push_back(parent); + std::deque new_entries; + + auto const& wildcard_index = map.template get(); + + while (!entries.empty()) { + new_entries.resize(0); + + for (auto root : entries) { + // Find all entries below this node + for (auto i = wildcard_index.lower_bound(root); i != wildcard_index.end() && i->parent_id == root; ++i) { + // Should we ignore system matches + if (!ignore_system || i->name.empty() || i->name[0] != '$') { + if (i->value) { + callback(*i->value); + } + + new_entries.push_back(i->id); + } + } + } + + // Ignore system only on first level + ignore_system = false; + std::swap(entries, new_entries); + } + } + + // Find all topics that match the specified topic filter + template + void find_match(std::string_view topic_filter, Output&& callback) const { + std::deque entries; + entries.push_back(root); + + std::deque new_entries; + topic_filter_tokenizer(topic_filter, [this, &entries, &new_entries, &callback](std::string_view t) { + auto const& direct_index = map.template get(); + auto const& wildcard_index = map.template get(); + new_entries.resize(0); + + for (auto const& entry : entries) { + node_id_type parent = entry->id; + + if (t == std::string_view("+")) { + for (auto i = wildcard_index.lower_bound(parent); i != wildcard_index.end() && i->parent_id == parent; ++i) { + if (parent != root_node_id || i->name.empty() || i->name[0] != '$') { + new_entries.push_back(map.template project(i)); + } else { + break; + } + } + } else if (t == std::string_view("#")) { + match_hash_entries(parent, callback, parent == root_node_id); + return false; + } else { + direct_const_iterator i = direct_index.find(std::make_tuple(parent, t)); + if (i != direct_index.end()) { + new_entries.push_back(i); + } + } + } + + std::swap(new_entries, entries); + return !entries.empty(); + }); + + for (auto& entry : entries) { + if (entry->value) { + callback(*entry->value); + } + } + } + + // Remove a value at the specified topic + size_t erase_topic(std::string_view topic) { + auto path = find_topic(topic); + + // Reset the value if there is actually something stored + if (!path.empty() && path.back()->value) { + auto& direct_index = map.template get(); + direct_index.modify(path.back(), [](path_entry& entry) { entry.value = std::nullopt; }); + + // Do iterators stay valid when erasing ? I think they do ? + for (auto entry : path) { + direct_index.modify(entry, [](path_entry& entry) { entry.decrease_count(); }); + + if (entry->count == 0) { + map.erase(entry); + } + } + + return 1; + } + + return 0; + } + + // Increase the number of topics for this path + void increase_topics(std::vector const& path) { + auto& direct_index = map.template get(); + + for (auto& i : path) { + direct_index.modify(i, [](path_entry& entry) { entry.increase_count(); }); + } + } + + // Increase the map size (total number of topics stored) + void increase_map_size() { + if (map_size == std::numeric_limits::max()) { + throw_max_stored_topics(); + } + + ++map_size; + } + + // Decrease the map size (total number of topics stored) + void decrease_map_size(size_t count) { + BOOST_ASSERT(map_size >= count); + map_size -= count; + } + + void init_map() { + map_size = 0; + // Create the root node + root = map.insert(path_entry(root_parent_id, "", root_node_id)).first; + next_node_id = root_node_id + 1; + } + +public: + retained_topic_map() { init_map(); } + + // Insert a value at the specified topic + template + std::size_t insert_or_assign(std::string_view topic, V&& value) { + auto& direct_index = map.template get(); + auto path = this->find_topic(topic); + + if (path.empty()) { + auto new_topic = this->create_topic(topic); + direct_index.modify(new_topic, [&value](path_entry& entry) mutable { entry.value.emplace(std::forward(value)); }); + increase_map_size(); + return 1; + } + + if (!path.back()->value) { + this->increase_topics(path); + direct_index.modify(path.back(), [&value](path_entry& entry) mutable { entry.value.emplace(std::forward(value)); }); + increase_map_size(); + return 1; + } + + direct_index.modify(path.back(), [&value](path_entry& entry) mutable { entry.value.emplace(std::forward(value)); }); + + return 0; + } + + // Find all stored topics that math the specified topic_filter + template + void find(std::string_view topic_filter, Output&& callback) const { + find_match(topic_filter, std::forward(callback)); + } + + // Remove a stored value at the specified topic + std::size_t erase(std::string_view topic) { + auto result = erase_topic(topic); + decrease_map_size(result); + return result; + } + + // Get the number of entries stored in the map + std::size_t size() const { return map_size; } + + // Get the number of entries in the map (for debugging purpose only) + std::size_t internal_size() const { return map.size(); } + + // Clear all topics + void clear() { + map.clear(); + init_map(); + } + + // Dump debug information + template + void dump(Output& out) { + auto const& direct_index = map.template get(); + for (auto const& i : direct_index) { + out << i.parent_id << " " << i.name << " " << (i.value ? "init" : "-") << " " << i.count << '\n'; + } + } +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_RETAINED_TOPIC_MAP_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/security.hpp b/exes/mqtt-bridge/tests/inc/broker/security.hpp new file mode 100644 index 0000000000..815ec0dd99 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/security.hpp @@ -0,0 +1,694 @@ +#pragma once +// Copyright Wouter van Kleunen 2021 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SECURITY_HPP) +#define ASYNC_MQTT_BROKER_SECURITY_HPP + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-conversion" +#endif // defined(__GNUC__) + +// from https://github.com/okdshin/PicoSHA2 +// MIT license +#include + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif // defined(__GNUC__) + +#if ASYNC_MQTT_USE_TLS +#include +#endif + +#include +#include +#include + +namespace async_mqtt { + +/** Remove comments from a JSON file (comments start with # and are not inside ' ' or " ") */ +inline std::string json_remove_comments(std::istream& input) { + bool inside_comment = false; + bool inside_single_quote = false; + bool inside_double_quote = false; + + std::ostringstream result; + + while (true) { + char c; + if (input.get(c).eof()) + break; + + if (!inside_double_quote && !inside_single_quote && c == '#') + inside_comment = true; + if (!inside_comment && !inside_double_quote && c == '\'') + inside_single_quote = !inside_single_quote; + if (!inside_comment && !inside_single_quote && c == '"') + inside_double_quote = !inside_double_quote; + if (c == '\n') + inside_comment = false; + + if (!inside_comment) + result << c; + } + + return result.str(); +} + +struct security { + static constexpr char const* any_group_name = "@any"; + + struct authentication { + enum class method { sha256, plain_password, client_cert, anonymous, unauthenticated }; + + authentication(method auth_method = method::sha256, + std::optional const& digest = std::nullopt, + std::string const& salt = std::string()) + : auth_method(auth_method), digest(digest), salt(salt) {} + + method auth_method; + std::optional digest; + std::string salt; + + std::vector groups; + }; + + struct authorization { + enum class type { deny, allow, none }; + + authorization(std::string_view topic, std::size_t rule_nr) + : topic(topic), rule_nr(rule_nr), sub_type(type::none), pub_type(type::none) {} + + std::vector topic_tokens; + + std::string topic; + std::size_t rule_nr; + + type sub_type; + std::set sub; + + type pub_type; + std::set pub; + }; + + struct group { + std::string name; + std::vector members; + }; + + /** Return username of anonymous user */ + std::optional const& login_anonymous() const { return anonymous; } + + /** Return username of unauthorized user */ + std::optional const& login_unauthenticated() const { return unauthenticated; } + + template + static std::string to_hex(T start, T end) { + std::string result; + boost::algorithm::hex(start, end, std::back_inserter(result)); + return result; + } + + static std::string sha256hash(std::string_view message) { + std::vector hash(picosha2::k_digest_size); + picosha2::hash256(message.begin(), message.end(), hash.begin(), hash.end()); + return picosha2::bytes_to_hex_string(hash.begin(), hash.end()); + } + + bool login_cert(std::string_view username) const { + auto i = authentication_.find(std::string(username)); + return i != authentication_.end() && i->second.auth_method == security::authentication::method::client_cert; + } + + std::optional login(std::string_view username, std::string_view password) const { + auto i = authentication_.find(std::string(username)); + if (i != authentication_.end() && i->second.auth_method == security::authentication::method::sha256) { + return [&]() -> std::optional { + if (boost::iequals(i->second.digest.value(), sha256hash(i->second.salt + std::string(password)))) { + return std::string(username); + } else { + return std::nullopt; + } + }(); + + } else if (i != authentication_.end() && i->second.auth_method == security::authentication::method::plain_password) { + return [&]() -> std::optional { + if (i->second.digest.value() == password) { + return std::string(username); + } else { + return std::nullopt; + } + }(); + } + return std::nullopt; + } + + static authorization::type get_auth_type(std::string_view type) { + if (type == "allow") + return authorization::type::allow; + if (type == "deny") + return authorization::type::deny; + throw std::runtime_error("An invalid authorization type was specified: " + std::string(type)); + } + + static bool is_valid_group_name(std::string_view name) { + return !name.empty() && name[0] == '@'; // TODO: validate utf-8 + } + + static bool is_valid_user_name(std::string_view name) { + return !name.empty() && name[0] != '@'; // TODO: validate utf-8 + } + + std::size_t get_next_rule_nr() const { + std::size_t rule_nr = 0; + for (auto const& i : authorization_) { + rule_nr = std::max(rule_nr, i.rule_nr); + } + return rule_nr + 1; + } + + void default_config() { + char const* username = "anonymous"; + authentication login(authentication::method::anonymous); + authentication_.insert({ username, login }); + anonymous = username; + + char const* topic = "#"; + authorization auth(topic, get_next_rule_nr()); + auth.topic_tokens = get_topic_filter_tokens("#"); + auth.sub_type = authorization::type::allow; + auth.sub.insert(username); + auth.pub_type = authorization::type::allow; + auth.pub.insert(username); + authorization_.push_back(auth); + + groups_.insert({ std::string(any_group_name), group() }); + + validate(); + } + + std::size_t add_auth(std::string const& topic_filter, + std::set const& pub, + authorization::type auth_pub_type, + std::set const& sub, + authorization::type auth_sub_type) { + for (auto const& j : pub) { + if (!is_valid_user_name(j) && !is_valid_group_name(j)) { + throw std::runtime_error("An invalid username or groupname was specified for the authorization: " + j); + } + validate_entry("topic " + topic_filter, j); + } + + for (auto const& j : sub) { + if (!is_valid_user_name(j) && !is_valid_group_name(j)) { + throw std::runtime_error("An invalid username or groupname was specified for the authorization: " + j); + } + validate_entry("topic " + topic_filter, j); + } + + std::size_t rule_nr = get_next_rule_nr(); + authorization auth(topic_filter, rule_nr); + auth.topic_tokens = get_topic_filter_tokens(topic_filter); + auth.pub = pub; + auth.pub_type = auth_pub_type; + auth.sub = sub; + auth.sub_type = auth_sub_type; + + for (auto const& j : sub) { + auth_sub_map.insert_or_assign(topic_filter, j, std::make_pair(auth_sub_type, rule_nr)); + } + for (auto const& j : pub) { + auth_pub_map.insert_or_assign(topic_filter, j, std::make_pair(auth_pub_type, rule_nr)); + } + + authorization_.push_back(auth); + return rule_nr; + } + + void remove_auth(std::size_t rule_nr) { + for (auto i = authorization_.begin(); i != authorization_.end(); ++i) { + if (i->rule_nr == rule_nr) { + for (auto const& j : i->sub) { + auth_sub_map.erase(i->topic, j); + } + for (auto const& j : i->pub) { + auth_pub_map.erase(i->topic, j); + } + + authorization_.erase(i); + return; + } + } + } + + void load_json(std::istream& input) { + // Create a root + boost::property_tree::ptree root; + + std::istringstream input_without_comments(json_remove_comments(input)); + boost::property_tree::read_json(input_without_comments, root); + + groups_.insert({ std::string(any_group_name), group() }); + + for (auto const& i : root.get_child("authentication")) { + std::string name = i.second.get("name"); + if (!is_valid_user_name(name)) { + throw std::runtime_error("An invalid username was specified: " + name); + } + + std::string method = i.second.get("method"); + + if (method == "sha256") { + std::string digest = i.second.get("digest"); + std::string salt = i.second.get("salt", ""); + + authentication auth(authentication::method::sha256, digest, salt); + authentication_.insert({ name, auth }); + } else if (method == "plain_password") { + std::string digest = i.second.get("password"); + + authentication auth(authentication::method::plain_password, digest); + authentication_.insert({ name, auth }); + } else if (method == "client_cert") { + authentication auth(authentication::method::client_cert); + authentication_.insert({ name, auth }); + } else if (method == "anonymous") { + if (anonymous) { + throw std::runtime_error("Only a single anonymous user can be configured, anonymous user: " + *anonymous); + } + anonymous = name; + + authentication auth(authentication::method::anonymous); + authentication_.insert({ name, auth }); + } else if (method == "unauthenticated") { + if (unauthenticated) { + throw std::runtime_error("Only a single unauthenticated user can be configured, unauthenticated user: " + + *unauthenticated); + } + unauthenticated = name; + + authentication auth(authentication::method::unauthenticated); + authentication_.insert({ name, auth }); + } else { + throw std::runtime_error("An invalid method was specified: " + method); + } + } + if (root.get_child_optional("group")) { + for (auto const& i : root.get_child("group")) { + std::string name = i.second.get("name"); + if (!is_valid_group_name(name)) { + throw std::runtime_error("An invalid group name was specified: " + name); + } + + group group; + if (i.second.get_child_optional("members")) { + for (auto const& j : i.second.get_child("members")) { + auto username = j.second.get_value(); + if (!is_valid_user_name(username)) { + throw std::runtime_error("An invalid user name was specified: " + username); + } + group.members.push_back(username); + } + } + + groups_.insert({ name, group }); + } + } + + for (auto const& i : root.get_child("authorization")) { + std::string name = i.second.get("topic"); + if (!validate_topic_filter(name)) { + throw std::runtime_error("An invalid topic filter was specified: " + name); + } + + authorization auth(name, get_next_rule_nr()); + auth.topic_tokens = get_topic_filter_tokens(name); + + if (i.second.get_child_optional("allow")) { + auto& allow = i.second.get_child("allow"); + if (allow.get_child_optional("sub")) { + for (auto const& j : allow.get_child("sub")) { + auth.sub.insert(j.second.get_value()); + } + auth.sub_type = authorization::type::allow; + } + + if (allow.get_child_optional("pub")) { + for (auto const& j : allow.get_child("pub")) { + auth.pub.insert(j.second.get_value()); + } + auth.pub_type = authorization::type::allow; + } + } + + if (i.second.get_child_optional("deny")) { + auto& deny = i.second.get_child("deny"); + if (deny.get_child_optional("sub")) { + for (auto const& j : deny.get_child("sub")) { + auth.sub.insert(j.second.get_value()); + } + auth.sub_type = authorization::type::deny; + } + + if (deny.get_child_optional("pub")) { + for (auto const& j : deny.get_child("pub")) { + auth.pub.insert(j.second.get_value()); + } + auth.pub_type = authorization::type::deny; + } + } + authorization_.push_back(auth); + } + + validate(); + } + + template + void get_auth_sub_by_user(std::string_view username, T&& callback) const { + std::set username_and_groups; + username_and_groups.insert(std::string(username)); + + for (auto const& i : groups_) { + if (i.first == any_group_name || + std::find(i.second.members.begin(), i.second.members.end(), username) != i.second.members.end()) { + username_and_groups.insert(i.first); + } + } + + for (auto const& i : authorization_) { + if (i.sub_type != authorization::type::none) { + bool sets_intersect = false; + auto store_intersect = [&sets_intersect](std::string const&) mutable { sets_intersect = true; }; + + std::set_intersection(i.sub.begin(), i.sub.end(), username_and_groups.begin(), username_and_groups.end(), + boost::make_function_output_iterator(std::ref(store_intersect))); + + if (sets_intersect) { + std::forward(callback)(i); + } + } + } + } + + authorization::type auth_pub(std::string_view topic, std::string_view username) const { + authorization::type result_type = authorization::type::deny; + + std::set username_and_groups; + username_and_groups.insert(std::string(username)); + + for (auto const& i : groups_) { + if (i.first == any_group_name || + std::find(i.second.members.begin(), i.second.members.end(), username) != i.second.members.end()) { + username_and_groups.insert(i.first); + } + } + + std::size_t priority = 0; + auth_pub_map.find(topic, [&](std::string const& allowed_username, std::pair entry) { + if (username_and_groups.find(allowed_username) != username_and_groups.end()) { + if (entry.second >= priority) { + result_type = entry.first; + priority = entry.second; + } + } + }); + + return result_type; + } + + std::map auth_sub(std::string_view topic) const { + std::map result; + std::map priorities; + auth_sub_map.find(topic, [&](std::string const& allowed_username, std::pair entry) { + auto rit = result.find(allowed_username); + if (rit == result.end()) { + result.emplace(allowed_username, entry.first); + priorities.emplace(allowed_username, entry.second); + } else { + auto pit = priorities.find(allowed_username); + BOOST_ASSERT(pit != priorities.end()); + if (pit->second <= entry.second) { + pit->second = entry.second; + rit->second = entry.first; + } + } + }); + + return result; + } + + authorization::type auth_sub_user(std::map const& result, + std::string const& username) const { + auto i = result.find(username); + if (i != result.end()) + return i->second; + + for (auto const& i : groups_) { + if (i.first == any_group_name || + std::find(i.second.members.begin(), i.second.members.end(), username) != i.second.members.end()) { + auto j = result.find(i.first); + if (j != result.end()) + return j->second; + } + } + + return authorization::type::deny; + } + + static bool is_hash(std::string_view level) { return level == "#"; } + static bool is_plus(std::string_view level) { return level == "+"; } + static bool is_literal(std::string_view level) { return !is_hash(level) && !is_plus(level); } + + static std::optional is_subscribe_allowed(std::vector const& authorized_filter, + std::string_view subscription_filter) { + std::optional result; + auto append_result = [&result](std::string_view token) { + if (result) { + *result += topic_filter_separator; + result->append(token.data(), token.size()); + } else { + result = std::string(token); + } + }; + + auto filter_begin = authorized_filter.begin(); + + auto subscription_begin = subscription_filter.begin(); + auto subscription_next = topic_filter_tokenizer_next(subscription_begin, subscription_filter.end()); + + while (true) { + if (filter_begin == authorized_filter.end()) { + return std::nullopt; + } + + auto auth = *filter_begin; + ++filter_begin; + + if (is_hash(auth)) { + append_result(make_string_view(subscription_begin, subscription_filter.end())); + return result; + } + + auto sub = make_string_view(subscription_begin, subscription_next); + + if (is_hash(sub)) { + append_result(auth); + + while (filter_begin < authorized_filter.end()) { + append_result(*filter_begin); + ++filter_begin; + } + + return result; + } + + if (is_plus(auth)) { + append_result(sub); + } else if (is_plus(sub)) { + append_result(auth); + } else { + if (auth != sub) { + return std::nullopt; + } + + append_result(auth); + } + + if (subscription_next == subscription_filter.end()) + break; + subscription_begin = std::next(subscription_next); + subscription_next = topic_filter_tokenizer_next(subscription_begin, subscription_filter.end()); + } + + if (filter_begin < authorized_filter.end()) { + return std::nullopt; + } + + return result; + } + + static bool is_subscribe_denied(std::vector const& deny_filter, std::string_view subscription_filter) { + bool result = true; + auto filter_begin = deny_filter.begin(); + + auto tokens_count = topic_filter_tokenizer(subscription_filter, [&](auto sub) { + if (filter_begin == deny_filter.end()) { + result = false; + return false; + }; + + std::string deny = *filter_begin; + ++filter_begin; + + if (deny != sub) { + if (is_hash(deny)) { + result = true; + return false; + } + + if (is_hash(sub)) { + result = false; + return false; + } + + if (is_plus(deny)) { + result = true; + return true; + } + + result = false; + return false; + } + + return true; + }); + + return result && (tokens_count == deny_filter.size()); + } + + std::vector get_auth_sub_topics(std::string_view username, std::string_view topic_filter) const { + std::vector auth_topics; + get_auth_sub_by_user(username, [&](authorization const& i) { + if (i.sub_type == authorization::type::allow) { + auto entry = is_subscribe_allowed(i.topic_tokens, topic_filter); + if (entry) { + auth_topics.push_back(*entry); + } + } else { + for (auto j = auth_topics.begin(); j != auth_topics.end();) { + if (is_subscribe_denied(i.topic_tokens, topic_filter)) { + j = auth_topics.erase(j); + } else { + ++j; + } + } + } + }); + return auth_topics; + } + + /** + * @brief Determine if user is allowed to subscribe to the specified topic filter + * @param username The username to check + * @param topic_filter Topic filter the user would like to subscribe to + * @return true if the user is authorized + */ + bool is_subscribe_authorized(std::string_view username, std::string_view topic_filter) const { + return !get_auth_sub_topics(username, topic_filter).empty(); + } + + // Get the individual path elements of the topic filter + static std::vector get_topic_filter_tokens(std::string_view topic_filter) { + std::vector result; + topic_filter_tokenizer(topic_filter, [&result](auto str) { + result.push_back(std::string(str)); + return true; + }); + + return result; + } + + std::map authentication_; + std::map groups_; + + std::vector authorization_; + + std::optional anonymous; + std::optional unauthenticated; + + using auth_map_type = multiple_subscription_map>; + auth_map_type auth_pub_map; + auth_map_type auth_sub_map; + +private: + void validate_entry(std::string const& context, std::string const& name) const { + if (is_valid_group_name(name) && groups_.find(name) == groups_.end()) { + throw std::runtime_error("An invalid group name was specified for " + context + ": " + name); + } + if (is_valid_user_name(name) && authentication_.find(name) == authentication_.end()) { + throw std::runtime_error("An invalid username name was specified for " + context + ": " + name); + } + } + + void validate() { + for (auto const& i : groups_) { + for (auto const& j : i.second.members) { + auto iter = authentication_.find(j); + if (is_valid_user_name(j) && iter == authentication_.end()) + throw std::runtime_error("An invalid username name was specified for group " + i.first + ": " + j); + } + } + + std::string unsalted; + for (auto const& i : authentication_) { + if (i.second.auth_method == authentication::method::sha256 && i.second.salt.empty()) { + if (!unsalted.empty()) + unsalted += ", "; + unsalted += i.first; + } + } + + if (!unsalted.empty()) { + ASYNC_MQTT_LOG("mqtt_broker", warning) << "The following users have no salt specified: " << unsalted; + } + + for (auto const& i : authorization_) { + for (auto const& j : i.sub) { + validate_entry("topic " + i.topic, j); + + if (is_valid_user_name(j) || is_valid_group_name(j)) { + auth_sub_map.insert_or_assign(i.topic, j, std::make_pair(i.sub_type, i.rule_nr)); + } + } + for (auto const& j : i.pub) { + validate_entry("topic " + i.topic, j); + + if (is_valid_user_name(j) || is_valid_group_name(j)) { + auth_pub_map.insert_or_assign(i.topic, j, std::make_pair(i.pub_type, i.rule_nr)); + } + } + } + } +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SECURITY_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/session_state.hpp b/exes/mqtt-bridge/tests/inc/broker/session_state.hpp new file mode 100644 index 0000000000..1e61bb7544 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/session_state.hpp @@ -0,0 +1,609 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SESSION_STATE_HPP) +#define ASYNC_MQTT_BROKER_SESSION_STATE_HPP + +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace async_mqtt { + +namespace as = boost::asio; +namespace mi = boost::multi_index; + +template +class session_states; + +/** + * http://docs.oasis-open.org/mqtt/mqtt/v5.0/cs02/mqtt-v5.0-cs02.html#_Session_State + * + * 4.1 Session State + * In order to implement QoS 1 and QoS 2 protocol flows the Client and Server need to associate state with the Client + * Identifier, this is referred to as the Session State. The Server also stores the subscriptions as part of the Session + * State. The session can continue across a sequence of Network Connections. It lasts as long as the latest Network + * Connection plus the Session Expiry Interval. The Session State in the Server consists of: · The existence of a Session, + * even if the rest of the Session State is empty. · The Clients subscriptions, including any Subscription Identifiers. · QoS + * 1 and QoS 2 messages which have been sent to the Client, but have not been completely acknowledged. · QoS 1 and QoS 2 + * messages pending transmission to the Client and OPTIONALLY QoS 0 messages pending transmission to the Client. · QoS 2 + * messages which have been received from the Client, but have not been completely acknowledged. · The Will Message and the + * Will Delay Interval · If the Session is currently not connected, the time at which the Session will end and Session State + * will be discarded. + * + * Retained messages do not form part of the Session State in the Server, they are not deleted as a result of a Session + * ending. + */ +template +struct session_state : std::enable_shared_from_this> { + using this_type = session_state; + using epsp_type = Sp; + using epwp_type = typename epsp_type::weak_type; + using will_sender_type = std::function const& source_ss, + std::string topic, + std::vector payload, + pub::opts pubopts, + properties props)>; + + static std::shared_ptr> create( + as::io_context& timer_ioc, + mutex& mtx_subs_map, + sub_con_map& subs_map, + shared_target& shared_targets, + epsp_type epsp, + std::string client_id, + std::string const& username, + std::optional will, + will_sender_type will_sender, + bool clean_start, + std::optional will_expiry_interval, + std::optional session_expiry_interval) { + struct impl : session_state { + impl(as::io_context& timer_ioc, + mutex& mtx_subs_map, + sub_con_map& subs_map, + shared_target& shared_targets, + epsp_type epsp, + std::string client_id, + std::string const& username, + will_sender_type will_sender, + bool clean_start, + std::optional session_expiry_interval) + : session_state{ timer_ioc, mtx_subs_map, + subs_map, shared_targets, + force_move(epsp), force_move(client_id), + username, force_move(will_sender), + clean_start, force_move(session_expiry_interval) } {} + }; + std::shared_ptr> sssp = + std::make_shared(timer_ioc, mtx_subs_map, subs_map, shared_targets, force_move(epsp), force_move(client_id), + username, force_move(will_sender), clean_start, force_move(session_expiry_interval)); + sssp->update_will(timer_ioc, will, will_expiry_interval); + return sssp; + } + + ~session_state() { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "session destroy"; + send_will_impl(); + clean(); + } + + template + void become_offline(epsp_type epsp, SessionExpireHandler&& session_expire_handler) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "store inflight message"; + auto stored = epsp.get_stored_packets(); + for (auto& store : stored) { + std::shared_ptr tim_message_expiry; + store.visit(overload{ [&](v5::publish_packet const& p) { + for (auto const& prop : p.props()) { + prop.visit(overload{ + [&](property::message_expiry_interval const& v) { + tim_message_expiry = + std::make_shared(timer_ioc_, std::chrono::seconds(v.val())); + tim_message_expiry->async_wait( + [this, wp = std::weak_ptr(tim_message_expiry)](error_code ec) { + if (auto sp = wp.lock()) { + if (!ec) { + erase_inflight_message_by_expiry(sp); + } + } + }); + }, + [](auto const&) {} }); + } + }, + [&](auto const&) {} }); + + insert_inflight_message(force_move(store), force_move(tim_message_expiry)); + } + + qos2_publish_handled_ = epsp.get_qos2_publish_handled_pids(); + + if (session_expiry_interval_ && *session_expiry_interval_ != std::chrono::seconds(session_never_expire)) { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "session expiry interval timer set"; + + tim_session_expiry_ = std::make_shared(timer_ioc_, *session_expiry_interval_); + tim_session_expiry_->async_wait( + [this, wp = std::weak_ptr(tim_session_expiry_), + session_expire_handler = std::forward(session_expire_handler)](error_code ec) { + if (auto sp = wp.lock()) { + if (!ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "session expired"; + session_expire_handler(sp); + } + } + }); + } + } + + void renew(epsp_type epsp, + std::optional will, + bool clean_start, + std::optional will_expiry_interval, + std::optional session_expiry_interval) { + clean(); + epwp_ = epsp; + auto version = epsp.get_protocol_version(); + if (version == protocol_version::v3_1_1) { + remain_after_close_ = !clean_start; + } else { + BOOST_ASSERT(version == protocol_version::v5); + remain_after_close_ = + session_expiry_interval_ && *session_expiry_interval_ != std::chrono::steady_clock::duration::zero(); + } + update_will(timer_ioc_, force_move(will), will_expiry_interval); + session_expiry_interval_ = force_move(session_expiry_interval); + } + + void publish(epsp_type& epsp, + as::io_context& timer_ioc, + std::string pub_topic, + std::vector payload, + pub::opts pubopts, + properties props) { + auto send_publish = [this, epsp, pub_topic, payload = payload, pubopts, props, + wp = this->weak_from_this()](packet_id_type pid) mutable { + if (auto sp = wp.lock()) { + switch (version_) { + case protocol_version::v3_1_1: + epsp.async_send(v3_1_1::publish_packet{ pid, force_move(pub_topic), force_move(payload), pubopts }, + [this, epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, this) + << "epsp:" << epsp.get_address() << " " << ec.message(); + } + }); + break; + case protocol_version::v5: + epsp.async_send( + v5::publish_packet{ pid, force_move(pub_topic), force_move(payload), pubopts, force_move(props) }, + [this, epsp](error_code const& ec) { + if (ec) { + ASYNC_MQTT_LOG("mqtt_broker", info) + << ASYNC_MQTT_ADD_VALUE(address, this) << "epsp:" << epsp.get_address() << " " << ec.message(); + } + }); + break; + default: + BOOST_ASSERT(false); + break; + } + } + }; + + std::lock_guard g(mtx_offline_messages_); + if (offline_messages_.empty()) { + auto qos_value = pubopts.get_qos(); + if (qos_value == qos::at_least_once || qos_value == qos::exactly_once) { + epsp.async_acquire_unique_packet_id( + [send_publish = force_move(send_publish)](error_code const& ec, auto pid) mutable { + if (!ec) { + send_publish(pid); + return; + } + }); + return; + } else { + send_publish(0); + return; + } + } + + // offline_messages_ is not empty or packet_id_exhausted + offline_messages_.push_back(timer_ioc, force_move(pub_topic), force_move(payload), pubopts, force_move(props)); + } + + void deliver(as::io_context& timer_ioc, + std::string pub_topic, + std::vector payload, + pub::opts pubopts, + properties props) { + if (auto epsp = lock()) { + publish(epsp, timer_ioc, force_move(pub_topic), force_move(payload), pubopts, force_move(props)); + } else { + std::lock_guard g(mtx_offline_messages_); + offline_messages_.push_back(timer_ioc, force_move(pub_topic), force_move(payload), pubopts, force_move(props)); + } + } + + void set_clean_handler(std::function handler) { clean_handler_ = force_move(handler); } + + void clean() { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "clean"; + if (clean_handler_) + clean_handler_(); + if (tim_will_expiry_) + tim_will_expiry_->cancel(); + will_value_ = std::nullopt; + { + std::lock_guard g(mtx_inflight_messages_); + inflight_messages_.clear(); + } + { + std::lock_guard g(mtx_offline_messages_); + offline_messages_.clear(); + } + unsubscribe_all(); + shared_targets_.erase(*this); + tim_will_delay_.cancel(); + + session_expiry_interval_ = std::nullopt; + if (tim_session_expiry_) + tim_session_expiry_->cancel(); + qos2_publish_handled_.clear(); + response_topic_ = std::nullopt; + } + + template + void subscribe(std::string share_name, + std::string topic_filter, + sub::opts subopts, + PublishRetainHandler&& h, + std::optional sid = std::nullopt) { + subscription sub{ *this, share_name, topic_filter, subopts, sid }; + if (!share_name.empty()) { + shared_targets_.insert(share_name, topic_filter, sub, *this); + } + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "subscribe" + << " share_name:" << share_name << " topic_filter:" << topic_filter + << " qos:" << subopts.get_qos(); + + auto handle_ret = [&] { + std::lock_guard g{ mtx_subs_map_ }; + return subs_map_.insert_or_assign(force_move(topic_filter), client_id_, force_move(sub)); + }(); + + auto rh = subopts.get_retain_handling(); + + if (handle_ret.second) { // insert + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "subscription inserted"; + + handles_.insert(handle_ret.first); + if (rh == sub::retain_handling::send || rh == sub::retain_handling::send_only_new_subscription) { + std::forward(h)(); + } + } else { // update + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "subscription updated"; + + if (rh == sub::retain_handling::send) { + std::forward(h)(); + } + } + } + + void unsubscribe(std::string const& share_name, std::string const& topic_filter) { + if (!share_name.empty()) { + shared_targets_.erase(share_name, topic_filter, *this); + } + std::lock_guard g{ mtx_subs_map_ }; + auto handle = subs_map_.lookup(topic_filter); + if (handle) { + handles_.erase(*handle); + subs_map_.erase(*handle, client_id_); + } + } + + void unsubscribe_all() { + { + std::lock_guard g{ mtx_subs_map_ }; + for (auto const& h : handles_) { + subs_map_.erase(h, client_id_); + } + } + handles_.clear(); + } + + void update_will(as::io_context& timer_ioc, + std::optional will, + std::optional will_expiry_interval) { + tim_will_expiry_.reset(); + will_value_ = force_move(will); + + if (will_value_ && will_expiry_interval) { + tim_will_expiry_ = std::make_shared(timer_ioc, *will_expiry_interval); + tim_will_expiry_->async_wait([this, wp = std::weak_ptr(tim_will_expiry_)](error_code ec) { + if (auto sp = wp.lock()) { + if (!ec) { + clear_will(); + } + } + }); + } + } + + void clear_will() { + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "clear will. cid:" << client_id_; + tim_will_expiry_.reset(); + will_value_ = std::nullopt; + } + + void send_will() { + if (!will_value_) + return; + + auto wd_sec = [&]() -> std::size_t { + std::optional wd_opt; + for (auto const& prop : will_value_->props()) { + prop.visit(overload{ [&](property::will_delay_interval const& v) { wd_opt.emplace(v); }, [&](auto const&) {} }); + if (wd_opt) + return wd_opt->val(); + } + return 0; + }(); + + if (remain_after_close_ && wd_sec != 0) { + ASYNC_MQTT_LOG("mqtt_broker", trace) + << ASYNC_MQTT_ADD_VALUE(address, this) << "set will_delay. cid:" << client_id_ << " delay:" << wd_sec; + tim_will_delay_.expires_after(std::chrono::seconds(wd_sec)); + tim_will_delay_.async_wait([this](error_code ec) { + if (!ec) { + send_will_impl(); + } + }); + } else { + send_will_impl(); + } + } + + void insert_inflight_message(store_packet_variant msg, std::shared_ptr tim_message_expiry) { + std::lock_guard g(mtx_inflight_messages_); + inflight_messages_.insert(force_move(msg), force_move(tim_message_expiry)); + } + + void send_inflight_messages() { + if (auto epsp = lock()) { + std::lock_guard g(mtx_inflight_messages_); + inflight_messages_.send_all_messages(epsp); + } + } + + void erase_inflight_message_by_expiry(std::shared_ptr const& sp) { + std::lock_guard g(mtx_inflight_messages_); + auto& idx = inflight_messages_.get(); + auto [b, e] = idx.equal_range(sp); + while (b != e) { + ASYNC_MQTT_LOG("mqtt_broker", info) << "message expired:" << b->packet(); + b = idx.erase(b); + } + } + + std::size_t erase_inflight_message_by_packet_id(packet_id_type packet_id) { + std::lock_guard g(mtx_inflight_messages_); + auto& idx = inflight_messages_.get(); + return idx.erase(packet_id); + } + + void send_all_offline_messages() { + if (auto epsp = lock()) { + std::lock_guard g(mtx_offline_messages_); + offline_messages_.send_until_fail(epsp, get_protocol_version()); + } + } + + void send_offline_messages_by_packet_id_release() { + if (auto epsp = lock()) { + std::lock_guard g(mtx_offline_messages_); + offline_messages_.send_until_fail(epsp, get_protocol_version()); + } + } + + protocol_version get_protocol_version() const { return version_; } + + std::string const& client_id() const { return client_id_; } + + std::string const& get_username() const { return username_; } + + void inherit(epsp_type epsp, + std::optional will, + std::optional will_expiry_interval, + std::optional session_expiry_interval) { + ASYNC_MQTT_LOG("mqtt_broker", info) << ASYNC_MQTT_ADD_VALUE(address, epsp.get_address()) << "inherit"; + + epwp_ = epsp; + auto version = epsp.get_protocol_version(); + if (version == protocol_version::v3_1_1) { + remain_after_close_ = true; + } else { + BOOST_ASSERT(version == protocol_version::v5); + remain_after_close_ = + session_expiry_interval_ && *session_expiry_interval_ != std::chrono::steady_clock::duration::zero(); + } + // for old will + tim_will_delay_.cancel(); + clear_will(); + // for new will + update_will(timer_ioc_, force_move(will), will_expiry_interval); + + session_expiry_interval_ = force_move(session_expiry_interval); + epsp.restore_qos2_publish_handled_pids(qos2_publish_handled_); + } + + epsp_type lock() { return epwp_.lock(); } + + std::optional session_expiry_interval() const { return session_expiry_interval_; } + + void set_response_topic(std::string topic) { response_topic_.emplace(force_move(topic)); } + + std::optional get_response_topic() const { return response_topic_; } + + bool remain_after_close() const { return remain_after_close_; } + +private: + // constructor + session_state(as::io_context& timer_ioc, + mutex& mtx_subs_map, + sub_con_map& subs_map, + shared_target& shared_targets, + epsp_type epsp, + std::string client_id, + std::string const& username, + will_sender_type will_sender, + bool clean_start, + std::optional session_expiry_interval) + : timer_ioc_(timer_ioc), mtx_subs_map_(mtx_subs_map), subs_map_(subs_map), shared_targets_(shared_targets), + epwp_(epsp), version_(epsp.get_protocol_version()), client_id_(force_move(client_id)), username_(username), + session_expiry_interval_(force_move(session_expiry_interval)), tim_will_delay_(timer_ioc_), + will_sender_(force_move(will_sender)), remain_after_close_([&] { + if (version_ == protocol_version::v3_1_1) { + return !clean_start; + } else { + BOOST_ASSERT(version_ == protocol_version::v5); + return session_expiry_interval_ && *session_expiry_interval_ != std::chrono::steady_clock::duration::zero(); + } + }()) {} + + void send_will_impl() { + if (!will_value_) + return; + + ASYNC_MQTT_LOG("mqtt_broker", trace) << ASYNC_MQTT_ADD_VALUE(address, this) << "send will. cid:" << client_id_; + + auto topic = force_move(will_value_->topic()); + auto payload = force_move(will_value_->message_as_buffer()); + auto opts = will_value_->get_qos() | will_value_->get_retain(); + auto props = force_move(will_value_->props()); + properties forward_props; + + for (auto&& prop : props) { + force_move(prop).visit(overload{ [&](property::will_delay_interval&&) { + // WillDelayInterval is not forwarded + }, + [&](auto&& p) { forward_props.push_back(force_move(p)); } }); + } + + will_value_ = std::nullopt; + if (tim_will_expiry_) { + auto d = + std::chrono::duration_cast(tim_will_expiry_->expiry() - std::chrono::steady_clock::now()) + .count(); + if (d < 0) + d = 0; + + bool set = false; + for (auto& prop : forward_props) { + prop.visit(overload{ [&](property::message_expiry_interval& v) { + v = property::message_expiry_interval{ static_cast(d) }; + set = true; + }, + [&](auto&) {} }); + if (set) + break; + } + } + if (will_sender_) { + will_sender_(*this, force_move(topic), std::vector{ force_move(payload) }, opts, force_move(forward_props)); + } + } + +private: + friend class session_states; + + as::io_context& timer_ioc_; + std::shared_ptr tim_will_expiry_; + std::optional will_value_; + + mutex& mtx_subs_map_; + sub_con_map& subs_map_; + shared_target& shared_targets_; + epwp_type epwp_; + protocol_version version_; + std::string client_id_; + + std::string username_; + + std::optional session_expiry_interval_; + std::shared_ptr tim_session_expiry_; + + mutable mutex mtx_inflight_messages_; + inflight_messages inflight_messages_; + + mutable mutex mtx_offline_messages_; + offline_messages offline_messages_; + + using elem_type = typename sub_con_map::handle; + std::set handles_; // to efficient remove + + as::steady_timer tim_will_delay_; + will_sender_type will_sender_; + bool remain_after_close_; + + std::set qos2_publish_handled_; + + std::optional response_topic_; + std::function clean_handler_; +}; + +template +class session_states { + using epsp_type = Sp; + using epwp_type = typename epsp_type::weak_type; + +public: + template + decltype(auto) get() { + return entries_.template get(); + } + + template + decltype(auto) get() const { + return entries_.template get(); + } + + void clear() { entries_.clear(); } + +private: + // The mi_session_online container holds the relevant data about an active connection with the broker. + // It can be queried either with the clientid, or with the shared pointer to the mqtt endpoint object + using ss_type = session_state; + using elem_type = std::shared_ptr; + using mi_session_state = + mi::multi_index_container, mi::key<&ss_type::epwp_>>, + mi::ordered_unique, mi::key<&ss_type::username_, &ss_type::client_id_>>, + mi::ordered_non_unique, mi::key<&ss_type::tim_session_expiry_>>>>; + + mi_session_state entries_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SESSION_STATE_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/session_state_fwd.hpp b/exes/mqtt-bridge/tests/inc/broker/session_state_fwd.hpp new file mode 100644 index 0000000000..4aa97fcbc2 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/session_state_fwd.hpp @@ -0,0 +1,23 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SESSION_STATE_FWD_HPP) +#define ASYNC_MQTT_BROKER_SESSION_STATE_FWD_HPP + +#include // reference_wrapper + +namespace async_mqtt { + +template +struct session_state; + +template +using session_state_ref = std::reference_wrapper>; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SESSION_STATE_FWD_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/shared_target.hpp b/exes/mqtt-bridge/tests/inc/broker/shared_target.hpp new file mode 100644 index 0000000000..4645925d24 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/shared_target.hpp @@ -0,0 +1,59 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SHARED_TARGET_HPP) +#define ASYNC_MQTT_BROKER_SHARED_TARGET_HPP + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace async_mqtt { + +namespace mi = boost::multi_index; + +template +class shared_target { +public: + void insert(std::string share_name, std::string topic_filter, subscription sub, session_state& ss); + void erase(std::string share_name, std::string topic_filter, session_state const& ss); + void erase(session_state const& ss); + std::optional, subscription>> get_target(std::string const& share_name, + std::string const& topic_filter); + +private: + struct entry { + entry(std::string share_name, session_state& ss, std::chrono::time_point tp); + + std::string const& client_id() const; + std::string share_name; + session_state_ref ssr; + std::chrono::time_point tp; + std::map> tf_subs; + }; + + using mi_shared_target = mi::multi_index_container< + entry, + mi::indexed_by, mi::key<&entry::client_id, &entry::share_name>>, + mi::ordered_non_unique, mi::key<&entry::share_name, &entry::tp>>>>; + + mutable mutex mtx_targets_; + mi_shared_target targets_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SHARED_TARGET_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/shared_target_impl.hpp b/exes/mqtt-bridge/tests/inc/broker/shared_target_impl.hpp new file mode 100644 index 0000000000..e2f5fcdcc9 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/shared_target_impl.hpp @@ -0,0 +1,113 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SHARED_TARGET_IMPL_HPP) +#define ASYNC_MQTT_BROKER_SHARED_TARGET_IMPL_HPP + +#include + +#include +#include + +namespace async_mqtt { + +template +inline void shared_target::insert(std::string share_name, + std::string topic_filter, + subscription sub, + session_state& ss) { + std::lock_guard g{ mtx_targets_ }; + auto& idx = targets_.template get(); + auto it = idx.lower_bound(std::make_tuple(ss.client_id(), share_name)); + if (it == idx.end() || (it->share_name != share_name || it->client_id() != ss.client_id())) { + it = idx.emplace_hint(it, force_move(share_name), ss, std::chrono::steady_clock::now()); + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& st = const_cast(*it); + bool inserted; + std::tie(std::ignore, inserted) = st.tf_subs.emplace(force_move(topic_filter), force_move(sub)); + BOOST_ASSERT(inserted); + } else { + // entry exists + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& st = const_cast(*it); + st.tf_subs.emplace(force_move(topic_filter), force_move(sub)); // ignore overwrite + } +} + +template +inline void shared_target::erase(std::string share_name, std::string topic_filter, session_state const& ss) { + std::lock_guard g{ mtx_targets_ }; + auto& idx = targets_.template get(); + auto it = idx.find(std::make_tuple(ss.client_id(), share_name)); + if (it == idx.end()) { + ASYNC_MQTT_LOG("mqtt_broker", warning) + << "attempt to erase non exist entry" + << " share_name:" << share_name << " topic_filtere:" << topic_filter << " client_id:" << ss.client_id(); + return; + } + + // entry exists + + // const_cast is appropriate here + // See https://github.com/boostorg/multi_index/issues/50 + auto& st = const_cast(*it); + st.tf_subs.erase(topic_filter); + if (it->tf_subs.empty()) { + idx.erase(it); + } +} + +template +inline void shared_target::erase(session_state const& ss) { + std::lock_guard g{ mtx_targets_ }; + auto& idx = targets_.template get(); + auto r = idx.equal_range(ss.client_id()); + idx.erase(r.first, r.second); +} + +template +inline std::optional, subscription>> shared_target::get_target( + std::string const& share_name, + std::string const& topic_filter) { + std::lock_guard g{ mtx_targets_ }; + // get share_name matched range ordered by timestamp (ascending) + auto& idx = targets_.template get(); + auto r = idx.equal_range(share_name); + for (; r.first != r.second; ++r.first) { + auto const& elem = *r.first; + auto it = elem.tf_subs.find(topic_filter); + + // no share_name/topic_filter matched + if (it == elem.tf_subs.end()) + continue; + + // matched + // update timestamp (timestamp is key) + idx.modify(r.first, [](auto& e) { e.tp = std::chrono::steady_clock::now(); }); + return std::make_tuple(elem.ssr, it->second); + } + return std::nullopt; +} + +template +inline shared_target::entry::entry(std::string share_name, + session_state& ss, + std::chrono::time_point tp) + : share_name{ force_move(share_name) }, ssr{ ss }, tp{ force_move(tp) } {} + +template +inline std::string const& shared_target::entry::client_id() const { + return ssr.get().client_id(); +} + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SHARED_TARGET_IMPL_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/sub_con_map.hpp b/exes/mqtt-bridge/tests/inc/broker/sub_con_map.hpp new file mode 100644 index 0000000000..123949c6ae --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/sub_con_map.hpp @@ -0,0 +1,21 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SUB_CON_MAP_HPP) +#define ASYNC_MQTT_BROKER_SUB_CON_MAP_HPP + +#include +#include + +namespace async_mqtt { + +template +using sub_con_map = multiple_subscription_map>; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SUB_CON_MAP_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/subscription.hpp b/exes/mqtt-bridge/tests/inc/broker/subscription.hpp new file mode 100644 index 0000000000..70b443425e --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/subscription.hpp @@ -0,0 +1,43 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SUBSCRIPTION_HPP) +#define ASYNC_MQTT_BROKER_SUBSCRIPTION_HPP + +#include +#include + +#include +#include +#include + +namespace async_mqtt { + +template +struct subscription { + subscription(session_state_ref ss, + std::string sharename, + std::string topic, + sub::opts opts, + std::optional sid) + : ss{ ss }, sharename{ force_move(sharename) }, topic{ force_move(topic) }, opts{ opts }, sid{ sid } {} + + session_state_ref ss; + std::string sharename; + std::string topic; + sub::opts opts; + std::optional sid; +}; + +template +inline bool operator<(subscription const& lhs, subscription const& rhs) { + return &lhs.ss.get() < &rhs.ss.get(); +} + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SUBSCRIPTION_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/subscription_map.hpp b/exes/mqtt-bridge/tests/inc/broker/subscription_map.hpp new file mode 100644 index 0000000000..8b117dc703 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/subscription_map.hpp @@ -0,0 +1,681 @@ +#pragma once +// Copyright Wouter van Kleunen 2019 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_SUBSCRIPTION_MAP_HPP) +#define ASYNC_MQTT_BROKER_SUBSCRIPTION_MAP_HPP + +#include +#include +#include + +#include +#include + +#include + +#include + +namespace async_mqtt { + +/** + * + * In MQTT we have: + * Clients subscribed with certain topic filters, topic filters are path with may contain wildcards such as + * + and # + * . A subscription to "#" will not receive any messages published to a topic beginning with a $ + * · A subscription to "+/monitor/Clients" will not receive any messages published to "$SYS/monitor/Clients" + * · A subscription to "$SYS/#" will receive messages published to topics beginning with "$SYS/" + * · A subscription to "$SYS/monitor/Clients/+" will receive messages published to "$SYS/monitor/Clients/" + * · For a Client to receive messages from topics that begin with $SYS/ and from topics that don’t begin with a $, it + * has to subscribe to both “#” and “$SYS/#” Check whether a string is a valid subscription using 'mqtt_valid_subscription' + * + * Topics being published, a topic is a sort of path and does not contain wildcards + * · $SYS/ has been widely adopted as a prefix to topics that contain Server-specific information or control APIs + * · Applications cannot use a topic with a leading $ character for their own purposes + * Check whether a string is a valid topic using 'mqtt_valid_topic' + * + * + * We introduce two data structures: + * . A subscription map, storing a topic_filter -> data + * Using a published topic, we can find all topic filters which match the specified topic + * . A stored topic map, storing topic -> data + * Using a new topic filter, we can find all stored topics which match the specified topic filter + * + * Subscription map stores all entries in a tree + * the tree starts from a root node, and topic filters are tokenized and stored in the tree + * + * For example if the topic_filter example/monitor/Clients is stored, the following nodes are created: + * root -> example -> monitor -> Clients + * + * Every node in the tree may store one or multiple subscribers. Nodes store a reference count to the number of subscribers + * so for example, if we store the following topic_filters: + * example/ + * example/monitor/Clients + * + * the subscription map looks as follows: + * root(2) -> example(2) -> monitor(1) -> Clients (1) + * + * hash and + are stored as normal nodes within the tree, but the parent node knows if a hash child is available. This + * improves the matching, no extra lookup is required to see if a # or + child is available in a child node: + * + * example/# + * + * stores the following tree: + * root -> example (hash: yes) -> # + * + * and + * + * example/+ + * + * stores the following tree: + * root -> example (plus: yes) -> # + * + * all node entries are stored in a single hash map. The key for every node is: (parent node id, path) + * + * so if we store: root/example/test + * root (id:1) -> example (id:2, key:1,example) -> test (id:3, key:2,test) + * + * also, every node stores the key of its parent, allowing quick traversing from leaf to root of the tree + */ + +// Combined storage for count and flags +// we can have 32bit or 64bit version + +// Compile error on other platforms (not 32 or 64 bit) +template +struct count_storage { + static_assert(N == 4 || N == 8, + "Subscription map count_storage only knows how to handle architectures with 32 or 64 bit size_t: please " + "update to support your platform."); +}; + +template <> +struct count_storage<4> { +public: + count_storage(std::uint32_t v = 1) : value_(v & 0x3fffffffUL), has_hash_child_(false), has_plus_child_(false) {} + + static constexpr std::size_t max() { return std::numeric_limits::max() >> 2; } + + std::uint32_t value() const { return value_; } + void set_value(std::uint32_t v) { value_ = v & 0x3fffffffUL; } + void increment_value() { ++value_; } + void decrement_value() { --value_; } + + bool has_hash_child() const { return has_hash_child_; } + void set_hash_child(bool v) { has_hash_child_ = v; } + + bool has_plus_child() const { return has_plus_child_; } + void set_plus_child(bool v) { has_plus_child_ = v; } + +private: + std::uint32_t value_ : 30; + std::uint32_t has_hash_child_ : 1; + std::uint32_t has_plus_child_ : 1; +}; + +template <> +struct count_storage<8> { +public: + count_storage(std::uint64_t v = 1) : value_(v & 0x3fffffffffffffffULL), has_hash_child_(false), has_plus_child_(false) {} + + static constexpr std::uint64_t max() { return std::numeric_limits::max() >> 2; } + + std::uint64_t value() const { return value_; } + void set_value(std::uint64_t v) { value_ = v & 0x3fffffffffffffffULL; } + void increment_value() { ++value_; } + void decrement_value() { --value_; } + + bool has_hash_child() const { return has_hash_child_; } + void set_hash_child(bool v) { has_hash_child_ = v; } + + bool has_plus_child() const { return has_plus_child_; } + void set_plus_child(bool v) { has_plus_child_ = v; } + +private: + std::uint64_t value_ : 62; + std::uint64_t has_hash_child_ : 1; + std::uint64_t has_plus_child_ : 1; +}; + +template +class subscription_map_base { +public: + using node_id_type = std::size_t; + using path_entry_key = std::pair; + using handle = path_entry_key; + +private: + // Generate a node id for a new node + node_id_type generate_node_id() { + if (next_node_id == std::numeric_limits::max()) + throw_max_stored_topics(); + return ++next_node_id; + } + + using count_storage_type = count_storage; + + struct path_entry { + node_id_type id; + path_entry_key parent; + + count_storage_type count; + + Value value; + + path_entry(node_id_type id, path_entry_key parent) : id(id), parent(parent) {} + }; + + // Increase the subscription count for a specific node + static void increase_count_storage(count_storage_type& count) { + if (count.value() == count_storage_type::max()) { + throw_max_stored_topics(); + } + + count.increment_value(); + } + + // Decrease the subscription count for a specific node + static void decrease_count_storage(count_storage_type& count) { + BOOST_ASSERT(count.value() > 0); + count.decrement_value(); + } + + using this_type = subscription_map_base; + + // Use boost hash to hash pair in path_entry_key + using map_type = std::unordered_map>; + + map_type map; + using map_type_iterator = typename map_type::iterator; + using map_type_const_iterator = typename map_type::const_iterator; + + node_id_type next_node_id = 0; + +protected: + // Key and id of the root key + path_entry_key root_key; + node_id_type root_node_id; + + // Return the iterator of the root + map_type_iterator get_root() { return map.find(root_key); } + map_type_const_iterator get_root() const { return map.find(root_key); } + + // Map size tracks the total number of subscriptions within the map + size_t map_size = 0; + + map_type_iterator get_key(path_entry_key key) { return map.find(key); } + map_type_iterator begin() { return map.begin(); } + map_type_iterator end() { return map.end(); } + map_type const& get_map() const { return map; } + + handle path_to_handle(std::vector const& path) const { return path.back()->first; } + + std::vector find_topic_filter(std::string_view topic_filter) { + auto parent_id = get_root()->second.id; + std::vector path; + + topic_filter_tokenizer(topic_filter, [this, &path, &parent_id](std::string_view t) mutable { + auto entry = map.find(path_entry_key(parent_id, t)); + + if (entry == map.end()) { + path.clear(); + return false; + } + + path.push_back(entry); + parent_id = entry->second.id; + return true; + }); + + return path; + } + + std::vector create_topic_filter(std::string_view topic_filter) { + auto parent = get_root(); + + std::vector result; + + topic_filter_tokenizer(topic_filter, [this, &parent, &result](std::string_view t) mutable { + auto entry = map.find(path_entry_key(parent->second.id, t)); + + if (entry == map.end()) { + entry = map.emplace(path_entry_key(parent->second.id, t), path_entry(generate_node_id(), parent->first)).first; + + parent->second.count.set_plus_child(parent->second.count.has_plus_child() || (t == "+")); + parent->second.count.set_hash_child(parent->second.count.has_hash_child() || (t == "#")); + } else { + increase_count_storage(entry->second.count); + } + + result.push_back(entry); + parent = entry; + return true; + }); + + return result; + } + + // Remove a value at the specified path + void remove_topic_filter(std::vector const& path) { + bool remove_plus_child_flag = false; + bool remove_hash_child_flag = false; + + // Go through entries to remove + for (auto& entry : boost::adaptors::reverse(path)) { + if (remove_plus_child_flag) { + entry->second.count.set_plus_child(false); + remove_plus_child_flag = false; + } + + if (remove_hash_child_flag) { + entry->second.count.set_hash_child(false); + remove_hash_child_flag = false; + } + + decrease_count_storage(entry->second.count); + if (entry->second.count.value() == 0) { + remove_plus_child_flag = (entry->first.second == "+"); + remove_hash_child_flag = (entry->first.second == "#"); + + // Erase in unordered map only invalidates erased iterator + // other iterators are unaffected + map.erase(entry->first); + } + } + + auto root = get_root(); + if (remove_plus_child_flag) { + root->second.count.set_plus_child(false); + } + + if (remove_hash_child_flag) { + root->second.count.set_hash_child(false); + } + } + + template + static void find_match_impl(ThisType& self, std::string_view topic, Output&& callback) { + using iterator_type = decltype(self.map.end()); // const_iterator or iterator depends on self + + std::vector entries; + entries.push_back(self.get_root()); + + topic_filter_tokenizer(topic, [&self, &entries, &callback](std::string_view t) { + std::vector new_entries; + + for (auto& entry : entries) { + auto parent = entry->second.id; + auto i = self.map.find(path_entry_key(parent, t)); + if (i != self.map.end()) { + new_entries.push_back(i); + } + + if (entry->second.count.has_plus_child()) { + i = self.map.find(path_entry_key(parent, std::string_view("+"))); + if (i != self.map.end()) { + if (parent != self.root_node_id || t.empty() || t[0] != '$') { + new_entries.push_back(i); + } + } + } + + if (entry->second.count.has_hash_child()) { + i = self.map.find(path_entry_key(parent, std::string_view("#"))); + if (i != self.map.end()) { + if (parent != self.root_node_id || t.empty() || t[0] != '$') { + callback(i->second.value); + } + } + } + } + + std::swap(entries, new_entries); + return !entries.empty(); + }); + + for (auto& entry : entries) { + callback(entry->second.value); + } + } + + // Find all topic filters that match the specified topic + template + void find_match(std::string_view topic, Output&& callback) const { + find_match_impl(*this, topic, std::forward(callback)); + } + + // Find all topic filters and allow modification + template + void modify_match(std::string_view topic, Output&& callback) { + find_match_impl(*this, topic, std::forward(callback)); + } + + template + static void handle_to_iterators(ThisType& self, handle const& h, Output&& output) { + auto i = h; + while (i != self.root_key) { + auto entry_iter = self.map.find(i); + if (entry_iter == self.map.end()) { + throw_invalid_handle(); + } + + output(entry_iter); + i = entry_iter->second.parent; + } + } + + // Exceptions used + static void throw_invalid_topic_filter() { + throw std::runtime_error("Subscription map invalid topic filter was specified"); + } + static void throw_invalid_handle() { throw std::runtime_error("Subscription map invalid handle was specified"); } + static void throw_max_stored_topics() { + throw std::overflow_error("Subscription map maximum number of stored topic filters reached"); + } + + // Get the iterators of a handle + std::vector handle_to_iterators(handle const& h) { + std::vector result; + handle_to_iterators(*this, h, [&result](map_type_iterator i) { result.push_back(i); }); + std::reverse(result.begin(), result.end()); + return result; + } + + // Increase the number of subscriptions for this handle + void increase_subscriptions(handle const& h) { + handle_to_iterators(*this, h, [](map_type_iterator i) { increase_count_storage(i->second.count); }); + } + + // Increase the map size (total number of subscriptions stored) + void increase_map_size() { + if (map_size == std::numeric_limits::max()) { + throw_max_stored_topics(); + } + + ++map_size; + } + + // Decrease the map size (total number of subscriptions stored) + void decrease_map_size() { + BOOST_ASSERT(map_size > 0); + --map_size; + } + + // Increase the number of subscriptions for this path + void increase_subscriptions(std::vector const& path) { + for (auto i : path) { + increase_count_storage(i->second.count); + } + } + + subscription_map_base() { + // Create the root node + root_node_id = generate_node_id(); + root_key = path_entry_key(generate_node_id(), buffer()); + map.emplace(root_key, path_entry(root_node_id, path_entry_key())); + } + +public: + // Return the number of elements in the tree + std::size_t internal_size() const { return map.size(); } + + // Return the number of registered topic filters + std::size_t size() const { return this->map_size; } + + // Lookup a topic filter + std::optional lookup(std::string_view topic_filter) { + auto path = this->find_topic_filter(topic_filter); + if (path.empty()) + return std::optional(); + else + return this->path_to_handle(force_move(path)); + } + + // Get path of topic_filter + std::string handle_to_topic_filter(handle const& h) const { + std::string result; + + handle_to_iterators(*this, h, [&result](map_type_const_iterator i) { + if (result.empty()) { + result = std::string(i->first.second); + } else { + result = std::string(i->first.second) + "/" + result; + } + }); + + return result; + } +}; + +template +class single_subscription_map : public subscription_map_base> { +public: + // Handle of an entry + using handle = typename subscription_map_base::handle; + + // Insert a value at the specified topic_filter + template + std::pair insert(std::string_view topic_filter, V&& value) { + auto existing_subscription = this->find_topic_filter(topic_filter); + if (!existing_subscription.empty()) { + if (existing_subscription.back()->second.value) + return std::make_pair(this->path_to_handle(force_move(existing_subscription)), false); + + existing_subscription.back()->second.value.emplace(std::forward(value)); + return std::make_pair(this->path_to_handle(force_move(existing_subscription)), true); + } + + auto new_topic_filter = this->create_topic_filter(topic_filter); + new_topic_filter.back()->second.value = value; + this->increase_map_size(); + return std::make_pair(this->path_to_handle(force_move(new_topic_filter)), true); + } + + // Update a value at the specified topic filter + template + void update(std::string_view topic_filter, V&& value) { + auto path = this->find_topic_filter(topic_filter); + if (path.empty()) { + this->throw_invalid_topic_filter(); + } + + path.back()->second.value.emplace(std::forward(value)); + } + + template + void update(handle const& h, V&& value) { + auto entry_iter = this->get_key(h); + if (entry_iter == this->end()) { + this->throw_invalid_topic_filter(); + } + entry_iter->second.value.emplace(std::forward(value)); + } + + // Remove a value at the specified topic filter + std::size_t erase(std::string_view topic_filter) { + auto path = this->find_topic_filter(topic_filter); + if (path.empty() || !path.back()->second.value) { + return 0; + } + + this->remove_topic_filter(path); + this->decrease_map_size(); + return 1; + } + + // Remove a value using a handle + std::size_t erase(handle const& h) { + auto path = this->handle_to_iterators(h); + if (path.empty() || !path.back()->second.value) { + return 0; + } + + this->remove_topic_filter(path); + this->decrease_map_size(); + return 1; + } + + // Find all topic filters that match the specified topic + template + void find(std::string_view topic, Output&& callback) const { + this->find_match(topic, [&callback](std::optional const& value) { + if (value) { + callback(*value); + } + }); + } +}; + +template , + class Pred = std::equal_to, + class Cont = std::unordered_map>>> +class multiple_subscription_map : public subscription_map_base { +public: + using container_t = Cont; + + // Handle of an entry + using handle = typename subscription_map_base::handle; + + // Insert a key => value at the specified topic filter + // returns the handle and true if key was inserted, false if key was updated + template + std::pair insert_or_assign(std::string_view topic_filter, K&& key, V&& value) { + auto path = this->find_topic_filter(topic_filter); + if (path.empty()) { + auto new_topic_filter = this->create_topic_filter(topic_filter); + new_topic_filter.back()->second.value.emplace(std::forward(key), std::forward(value)); + this->increase_map_size(); + return std::make_pair(this->path_to_handle(force_move(new_topic_filter)), true); + } else { + auto& subscription_set = path.back()->second.value; + +#if __cpp_lib_unordered_map_try_emplace >= 201411L + auto insert_result = subscription_set.insert_or_assign(std::forward(key), std::forward(value)); + if (insert_result.second) { + this->increase_subscriptions(path); + this->increase_map_size(); + } + return std::make_pair(this->path_to_handle(force_move(path)), insert_result.second); +#else + auto iter = subscription_set.find(key); + if (iter == subscription_set.end()) { + subscription_set.emplace(std::forward(key), std::forward(value)); + this->increase_subscriptions(path); + this->increase_map_size(); + } else { + iter->second = std::forward(value); + } + return std::make_pair(this->path_to_handle(force_move(path)), iter == subscription_set.end()); + +#endif + } + } + + // Insert a key => value with a handle to the topic filter + // returns the handle and true if key was inserted, false if key was updated + template + std::pair insert_or_assign(handle const& h, K&& key, V&& value) { + auto h_iter = this->get_key(h); + if (h_iter == this->end()) { + this->throw_invalid_handle(); + } + + auto& subscription_set = h_iter->second.value; + +#if __cpp_lib_unordered_map_try_emplace >= 201411L + auto insert_result = subscription_set.insert_or_assign(std::forward(key), std::forward(value)); + if (insert_result.second) { + this->increase_subscriptions(h); + this->increase_map_size(); + } + return std::make_pair(h, insert_result.second); +#else + auto iter = subscription_set.find(key); + if (iter == subscription_set.end()) { + subscription_set.emplace(std::forward(key), std::forward(value)); + this->increase_subscriptions(h); + this->increase_map_size(); + } else { + iter->second = std::forward(value); + } + return std::make_pair(h, iter == subscription_set.end()); +#endif + } + + // Remove a value at the specified handle + // returns the number of removed elements + std::size_t erase(handle const& h, Key const& key) { + // Find the handle in the map + auto h_iter = this->get_key(h); + if (h_iter == this->end()) { + this->throw_invalid_handle(); + } + + // Remove the specified value + auto result = h_iter->second.value.erase(key); + if (result) { + this->remove_topic_filter(this->handle_to_iterators(h)); + this->decrease_map_size(); + } + + return result; + } + + // Remove a value at the specified topic filter + // returns the number of removed elements + std::size_t erase(std::string_view topic_filter, Key const& key) { + // Find the topic filter in the map + auto path = this->find_topic_filter(topic_filter); + if (path.empty()) { + return 0; + } + + // Remove the specified value + auto result = path.back()->second.value.erase(key); + if (result) { + this->decrease_map_size(); + this->remove_topic_filter(path); + } + + return result; + } + + // Find all topic filters that match the specified topic + template + void find(std::string_view topic, Output&& callback) const { + this->find_match(topic, [&callback](Cont const& values) { + for (auto const& i : values) { + callback(i.first, i.second); + } + }); + } + + // Find all topic filters that match and allow modification + template + void modify(std::string_view topic, Output&& callback) { + this->modify_match(topic, [&callback](Cont& values) { + for (auto& i : values) { + callback(i.first, i.second); + } + }); + } + + template + void dump(Output& out) { + out << "Root node id: " << this->root_node_id << std::endl; + for (auto const& i : this->get_map()) { + out << "(" << i.first.first << ", " << i.first.second << "): id: " << i.second.id + << ", size: " << i.second.value.size() << ", value: " << i.second.count.value << std::endl; + } + } +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_SUBSCRIPTION_MAP_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/tags.hpp b/exes/mqtt-bridge/tests/inc/broker/tags.hpp new file mode 100644 index 0000000000..cf7616c8b3 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/tags.hpp @@ -0,0 +1,27 @@ +#pragma once +// Copyright Takatoshi Kondo 2020 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_TAGS_HPP) +#define ASYNC_MQTT_BROKER_TAGS_HPP + +namespace async_mqtt { + +struct tag_seq {}; +struct tag_con {}; +struct tag_topic {}; +struct tag_topic_filter {}; +struct tag_con_topic_filter {}; +struct tag_cid {}; +struct tag_cid_topic_filter {}; +struct tag_tim {}; +struct tag_pid {}; +struct tag_sn_tp {}; +struct tag_cid_sn {}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_TAGS_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/topic_filter.hpp b/exes/mqtt-bridge/tests/inc/broker/topic_filter.hpp new file mode 100644 index 0000000000..7832c40081 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/topic_filter.hpp @@ -0,0 +1,304 @@ +#pragma once +// Copyright wkl04 2019 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_TOPIC_FILTER_HPP) +#define ASYNC_MQTT_BROKER_TOPIC_FILTER_HPP + +#include +#include +#include +#include + +#include + +#include + +namespace async_mqtt { + +static constexpr char topic_filter_separator = '/'; + +template +inline Iterator topic_filter_tokenizer_next(Iterator first, Iterator last) { + return std::find(first, last, topic_filter_separator); +} + +template +inline std::size_t topic_filter_tokenizer(Iterator first, Iterator last, Output write) { + std::size_t count = 1; + auto pos = topic_filter_tokenizer_next(first, last); + while (write(first, pos) && pos != last) { + first = std::next(pos); + pos = topic_filter_tokenizer_next(first, last); + ++count; + } + return count; +} + +template +inline std::size_t topic_filter_tokenizer(std::string_view str, Output write) { + return topic_filter_tokenizer( + std::begin(str), std::end(str), + [&write](std::string_view::const_iterator token_begin, std::string_view::const_iterator token_end) { + return write(make_string_view(token_begin, token_end)); + }); +} + +// TODO: Technically this function is simply wrong, since it's treating the +// topic pattern as if it were an ASCII sequence. +// To make this function correct per the standard, it would be necessary +// to conduct the search for the wildcard characters using a proper +// UTF-8 API to avoid problems of interpreting parts of multi-byte characters +// as if they were individual ASCII characters +constexpr bool validate_topic_filter(std::string_view topic_filter) { + /* + * Confirm the topic pattern is valid before registering it. + * Use rules from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106 + */ + + // All Topic Names and Topic Filters MUST be at least one character long + // Topic Names and Topic Filters are UTF-8 Encoded Strings; they MUST NOT encode to more than 65,535 bytes + if (topic_filter.empty() || (topic_filter.size() > std::numeric_limits::max())) { + return false; + } + + for (std::string_view::size_type idx = topic_filter.find_first_of(std::string_view("\0+#", 3)); + std::string_view::npos != idx; idx = topic_filter.find_first_of(std::string_view("\0+#", 3), idx + 1)) { + BOOST_ASSERT(('\0' == topic_filter[idx]) || ('+' == topic_filter[idx]) || ('#' == topic_filter[idx])); + if ('\0' == topic_filter[idx]) { + // Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) + return false; + } else if ('+' == topic_filter[idx]) { + /* + * Either must be the first character, + * or be preceeded by a topic seperator. + */ + if ((0 != idx) && ('/' != topic_filter[idx - 1])) { + return false; + } + + /* + * Either must be the last character, + * or be followed by a topic seperator. + */ + if ((topic_filter.size() - 1 != idx) && ('/' != topic_filter[idx + 1])) { + return false; + } + } + // multilevel wildcard + else if ('#' == topic_filter[idx]) { + /* + * Must be absolute last character. + * Must only be one multi level wild card. + */ + if (idx != topic_filter.size() - 1) { + return false; + } + + /* + * If not the first character, then the + * immediately preceeding character must + * be a topic level separator. + */ + if ((0 != idx) && ('/' != topic_filter[idx - 1])) { + return false; + } + } else { + return false; + } + } + return true; +} + +// The following rules come from https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901247 +static_assert(!validate_topic_filter(""), "All Topic Names and Topic Filters MUST be at least one character long"); +static_assert(validate_topic_filter("/"), "A Topic Name or Topic Filter consisting only of the ‘/’ character is valid"); +static_assert(!validate_topic_filter(std::string_view("\0", 1)), + "Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000)"); +static_assert(validate_topic_filter(" "), "Topic Names and Topic Filters can include the space character"); +static_assert(validate_topic_filter("/////"), + "Topic level separators can appear anywhere in a Topic Filter or Topic Name. Adjacent Topic level separators " + "indicate a zero-length topic level"); +static_assert(validate_topic_filter("#"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(validate_topic_filter("/#"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(validate_topic_filter("+/#"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(!validate_topic_filter("+#"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(!validate_topic_filter("++"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(!validate_topic_filter("f#"), + "The multi-level wildcard character MUST be specified either on its own or following a topic level separator"); +static_assert(!validate_topic_filter("#/"), + "In either case the multi-level wildcard character MUST be the last character specified in the Topic Filter"); + +static_assert(validate_topic_filter("+"), + "The single-level wildcard can be used at any level in the Topic Filter, including first and last levels"); +static_assert(validate_topic_filter("+/bob/alice/sue"), + "The single-level wildcard can be used at any level in the Topic Filter, including first and last levels"); +static_assert(validate_topic_filter("bob/alice/sue/+"), + "The single-level wildcard can be used at any level in the Topic Filter, including first and last levels"); +static_assert(validate_topic_filter("+/bob/alice/sue/+"), + "The single-level wildcard can be used at any level in the Topic Filter, including first and last levels"); +static_assert(validate_topic_filter("+/bob/+/sue/+"), + "The single-level wildcard can be used at any level in the Topic Filter, including first and last levels"); +static_assert(validate_topic_filter("+/bob/+/sue/#"), + "The single-level wildcard can be used at more than one level in the Topic Filter and can be used in " + "conjunction with the multi-level wildcard"); +static_assert(!validate_topic_filter("+a"), + "Where it is used, the single-level wildcard MUST occupy an entire level of the filter."); +static_assert(!validate_topic_filter("a+"), + "Where it is used, the single-level wildcard MUST occupy an entire level of the filter."); +static_assert(!validate_topic_filter("/a+"), + "Where it is used, the single-level wildcard MUST occupy an entire level of the filter."); +static_assert(!validate_topic_filter("a+/"), + "Where it is used, the single-level wildcard MUST occupy an entire level of the filter."); +static_assert(!validate_topic_filter("/a+/"), + "Where it is used, the single-level wildcard MUST occupy an entire level of the filter."); + +constexpr bool validate_topic_name(std::string_view topic_name) { + /* + * Confirm the topic name is valid + * Use rules from https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901247 + */ + + // All Topic Names and Topic Filters MUST be at least one character long + // Topic Names and Topic Filters are UTF-8 Encoded Strings; they MUST NOT encode to more than 65,535 bytes + // The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name + // Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) + return !topic_name.empty() && (topic_name.size() <= std::numeric_limits::max()) && + (std::string_view::npos == topic_name.find_first_of(std::string_view("\0+#", 3))); +} + +// The following rules come from https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901247 +static_assert(!validate_topic_name(""), "All Topic Names and Topic Filters MUST be at least one character long"); +static_assert(validate_topic_name("/"), "A Topic Name or Topic Filter consisting only of the ‘/’ character is valid"); +static_assert(!validate_topic_name(std::string_view("\0", 1)), + "Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000)"); +static_assert(validate_topic_name(" "), "Topic Names and Topic Filters can include the space character"); +static_assert(validate_topic_name("/////"), + "Topic level separators can appear anywhere in a Topic Filter or Topic Name. Adjacent Topic level separators " + "indicate a zero-length topic level"); +static_assert(!validate_topic_name("#"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); +static_assert(!validate_topic_name("+"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); +static_assert(!validate_topic_name("/#"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); +static_assert(!validate_topic_name("+/#"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); +static_assert(!validate_topic_name("f#"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); +static_assert(!validate_topic_name("#/"), + "The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name"); + +constexpr bool compare_topic_filter(std::string_view topic_filter, std::string_view topic_name) { + if (!validate_topic_filter(topic_filter)) { + BOOST_ASSERT(validate_topic_filter(topic_filter)); + return false; + } + + if (!validate_topic_name(topic_name)) { + BOOST_ASSERT(validate_topic_name(topic_name)); + return false; + } + + // TODO: The Server MUST NOT match Topic Filters starting with a wildcard character (# or +) with Topic Names beginning + // with a $ character + for (std::string_view::size_type idx = topic_filter.find_first_of("+#"); std::string_view::npos != idx; + idx = topic_filter.find_first_of("+#")) { + BOOST_ASSERT(('+' == topic_filter[idx]) || ('#' == topic_filter[idx])); + + if ('+' == topic_filter[idx]) { + // Compare everything up to the first + + if (topic_filter.substr(0, idx) == topic_name.substr(0, idx)) { + /* + * We already know thanks to the topic filter being validated + * that the + symbol is directly touching '/'s on both sides + * (if not the first or last character), so we don't need to + * double check that. + * + * By simply removing the prefix that we've compared and letting + * the loop continue, we get the proper comparison of the '/'s + * automatically when the loop continues. + */ + topic_filter.remove_prefix(idx + 1); + /* + * It's a bit more complicated for the incoming topic though + * as we need to remove everything up to the next seperator. + */ + topic_name.remove_prefix(topic_name.find('/', idx)); + } else { + return false; + } + } + // multilevel wildcard + else { + /* + * Compare up to where the multilevel wild card is found + * and then anything after that matches the wildcard. + */ + return topic_filter.substr(0, idx) == topic_name.substr(0, idx); + } + } + + // No + or # found in the remaining topic filter. Just do a string compare. + return topic_filter == topic_name; +} + +static_assert(compare_topic_filter("bob", "bob"), "Topic Names and Topic Filters are case sensitive"); +static_assert(!compare_topic_filter("Bob", "bob"), "Topic Names and Topic Filters are case sensitive"); +static_assert(!compare_topic_filter("bob", "boB"), "Topic Names and Topic Filters are case sensitive"); +static_assert(!compare_topic_filter("/bob", "bob"), + "A leading or trailing ‘/’ creates a distinct Topic Name or Topic Filter"); +static_assert(!compare_topic_filter("bob/", "bob"), + "A leading or trailing ‘/’ creates a distinct Topic Name or Topic Filter"); +static_assert(!compare_topic_filter("bob", "/bob"), + "A leading or trailing ‘/’ creates a distinct Topic Name or Topic Filter"); +static_assert(!compare_topic_filter("bob", "bob/"), + "A leading or trailing ‘/’ creates a distinct Topic Name or Topic Filter"); +static_assert(compare_topic_filter("bob/alice", "bob/alice"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/alice/sue", "bob/alice/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob//////sue", "bob//////sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/#", "bob//////sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(!compare_topic_filter("bob///#", "bob/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/+/sue", "bob/alice/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(!compare_topic_filter("bob/+/sue", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("#", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/#", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/alice/#", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(compare_topic_filter("bob/alice/mary/#", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); +static_assert(!compare_topic_filter("bob/alice/mary/sue/#", "bob/alice/mary/sue"), + "Each non-wildcarded level in the Topic Filter has to match the corresponding level in the Topic Name " + "character for character for the match to succeed"); + +} // namespace async_mqtt + +#endif // MQTT_BROKER_TOPIC_FILTER_HPP diff --git a/exes/mqtt-bridge/tests/inc/broker/uuid.hpp b/exes/mqtt-bridge/tests/inc/broker/uuid.hpp new file mode 100644 index 0000000000..f7c21e5fc7 --- /dev/null +++ b/exes/mqtt-bridge/tests/inc/broker/uuid.hpp @@ -0,0 +1,27 @@ +#pragma once +// Copyright Takatoshi Kondo 2021 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_BROKER_UUID_HPP) +#define ASYNC_MQTT_BROKER_UUID_HPP + +#include + +#include +#include +#include + +namespace async_mqtt { + +inline std::string create_uuid_string() { + // See https://github.com/boostorg/uuid/issues/121 + thread_local boost::uuids::random_generator gen; + return boost::uuids::to_string(gen()); +} + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_BROKER_UUID_HPP diff --git a/exes/mqtt-bridge/tests/inc/endpoint_mock.hpp b/exes/mqtt-bridge/tests/inc/endpoint_mock.hpp index 53ae6eaa55..7b65095d5f 100644 --- a/exes/mqtt-bridge/tests/inc/endpoint_mock.hpp +++ b/exes/mqtt-bridge/tests/inc/endpoint_mock.hpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -62,33 +63,32 @@ using boost::asio::experimental::awaitable_operators::operator||; class endpoint_client_mock { public: - explicit endpoint_client_mock(asio::io_context& ctx, structs::ssl_active_e) : strand_(asio::make_strand(ctx)) {} + explicit endpoint_client_mock(asio::io_context& ctx, structs::ssl_active_e) : ctx_(ctx) {} - auto strand() -> asio::strand& { return strand_; } + auto get_executor() const -> asio::any_io_executor { return ctx_.get_executor(); } auto recv(async_mqtt::control_packet_type packet_t) -> asio::awaitable< package_v> { if (packet_t == async_mqtt::control_packet_type::suback) { package_v my_variant; - my_variant.set({ 0, { async_mqtt::suback_reason_code::granted_qos_0 } }); + my_variant.set( + async_mqtt::v5::suback_packet{ 0, { async_mqtt::suback_reason_code::granted_qos_0 } }); co_return my_variant; } else if (packet_t == async_mqtt::control_packet_type::publish) { package_v my_variant; - my_variant.set({ 0, - async_mqtt::allocate_buffer("topic"), - async_mqtt::allocate_buffer("payload"), - { async_mqtt::pub::retain::no }, - async_mqtt::properties{} }); + my_variant.set(async_mqtt::v5::publish_packet{ + 0, "topic", "payload", { async_mqtt::pub::retain::no }, async_mqtt::properties{} }); co_return my_variant; } package_v my_variant; - my_variant.set({ true, async_mqtt::connect_reason_code::success }); + my_variant.set( + async_mqtt::v5::connack_packet{ true, async_mqtt::connect_reason_code::success }); co_return my_variant; } template - auto send(args_t&&...) -> asio::awaitable { - co_return false; + auto send(args_t&&...) -> asio::awaitable> { + co_return std::make_tuple(std::error_code{}); } template @@ -112,7 +112,7 @@ class endpoint_client_mock { auto async_handshake() -> asio::awaitable { co_return; } private: - asio::strand strand_; + asio::io_context& ctx_; async_mqtt::tls::context tls_ctx_{ async_mqtt::tls::context::tlsv12 }; std::optional> mqtt_client_; std::optional> mqtts_client_; diff --git a/exes/mqtt-bridge/tests/src/integration_tests.cpp b/exes/mqtt-bridge/tests/src/integration_tests.cpp index 084c7c7422..2dc34528b9 100644 --- a/exes/mqtt-bridge/tests/src/integration_tests.cpp +++ b/exes/mqtt-bridge/tests/src/integration_tests.cpp @@ -1,12 +1,22 @@ -#ifdef __clang__ - #include #include #include +#include + #include #include -#include +// clang-format off +PRAGMA_CLANG_WARNING_PUSH_OFF(-Wshadow-uncaptured-local) +PRAGMA_CLANG_WARNING_PUSH_OFF(-Wshadow) +PRAGMA_CLANG_WARNING_PUSH_OFF(-Wshadow-field-in-constructor) +PRAGMA_CLANG_WARNING_PUSH_OFF(-Wdocumentation) +// clang-format on +#include +PRAGMA_CLANG_WARNING_POP +PRAGMA_CLANG_WARNING_POP +PRAGMA_CLANG_WARNING_POP +PRAGMA_CLANG_WARNING_POP #include #include @@ -70,36 +80,34 @@ class mqtt_client { auto handle_resolve(asio::ip::tcp::resolver::results_type eps) -> asio::awaitable { std::ignore = co_await async_connect(amep_->lowest_layer(), eps, asio::use_awaitable); - co_await amep_->send( + co_await amep_->async_send( async_mqtt::v5::connect_packet{ true, 0x1234, - async_mqtt::allocate_buffer("cid2"), - async_mqtt::nullopt, - async_mqtt::nullopt, - async_mqtt::nullopt, + "cid2", + std::nullopt, + std::nullopt, + std::nullopt, }, asio::use_awaitable); - co_await amep_->recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::connack }, asio::use_awaitable); + co_await amep_->async_recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::connack }, asio::use_awaitable); co_await send_subscribe(); } auto send_subscribe() -> asio::awaitable { - std::optional packet_id = amep_->acquire_unique_packet_id(); - auto sub_packet = - async_mqtt::v5::subscribe_packet{ packet_id.value(), - { { async_mqtt::allocate_buffer(topic_), async_mqtt::qos::at_most_once } } }; - co_await amep_->send(sub_packet, asio::use_awaitable); - co_await amep_->recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::suback }, asio::use_awaitable); + std::optional packet_id = amep_->acquire_unique_packet_id(); + auto sub_packet = async_mqtt::v5::subscribe_packet{ packet_id.value(), { { topic_, async_mqtt::qos::at_most_once } } }; + co_await amep_->async_send(sub_packet, asio::use_awaitable); + co_await amep_->async_recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::suback }, asio::use_awaitable); co_await receive_publish_packets(); } auto receive_publish_packets() -> asio::awaitable { while (true) { - auto p = - co_await amep_->recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::publish }, asio::use_awaitable); + auto p = co_await amep_->async_recv(async_mqtt::filter::match, { async_mqtt::control_packet_type::publish }, + asio::use_awaitable); async_mqtt::v5::publish_packet const& p2 = p.template get(); - for (auto& payload : p2.payload()) { + for (auto& payload : p2.payload_as_buffer()) { messages_.push_back(payload); } } @@ -261,8 +269,3 @@ auto main(int argc, char* argv[]) -> int { return 0; } -#else -auto main() -> int { - return 0; -} -#endif diff --git a/exes/themis/inc/dbus_interface.hpp b/exes/themis/inc/dbus_interface.hpp index 1dbd7bcde8..8738020619 100644 --- a/exes/themis/inc/dbus_interface.hpp +++ b/exes/themis/inc/dbus_interface.hpp @@ -29,8 +29,13 @@ class interface { connection_->request_name(service_name.data()); interface_ = object_server_->add_unique_interface(object_path.data(), interface_name.data()); - interface_->register_method(std::string(methods::list_alarms), - [&]() -> std::string { return glz::write_json(database.list_alarms()); }); + interface_->register_method(std::string(methods::list_alarms), [&]() -> std::string { + auto const alarms_str{ glz::write_json(database.list_alarms()) }; + if (!alarms_str) { + throw dbus_error("Failed to serialize alarms"); + } + return alarms_str.value(); + }); interface_->register_method( std::string(methods::register_alarm), @@ -78,9 +83,13 @@ class interface { int64_t start, int64_t end) -> std::string { auto cstart = tfc::themis::alarm_database::timepoint_from_milliseconds(start); auto cend = tfc::themis::alarm_database::timepoint_from_milliseconds(end); - return glz::write_json(database.list_activations( + auto const activations_str{ glz::write_json(database.list_activations( locale, start_count, count, static_cast(alarm_level), - static_cast(active), cstart, cend)); + static_cast(active), cstart, cend)) }; + if (!activations_str) { + throw dbus_error("Failed to serialize activations"); + } + return activations_str.value(); }); // Signal alarm_id, current_activation, ack_status diff --git a/exes/themis/tests/themis_integration_test.cpp b/exes/themis/tests/themis_integration_test.cpp index 1da59763b8..9cb8c49072 100644 --- a/exes/themis/tests/themis_integration_test.cpp +++ b/exes/themis/tests/themis_integration_test.cpp @@ -97,7 +97,7 @@ int main(int argc, char** argv) { expect(!err) << err.message(); t.ran[0] = true; }); - t.ctx.run_for(2ms); + t.ctx.run_for(100ms); expect(t.ran[0]); t.client.list_alarms([&](const std::error_code& err, std::vector alarms) { expect(!err) << err.message(); @@ -271,7 +271,7 @@ int main(int argc, char** argv) { }; // TODO: The alarm is recreated but its state is not set again. - "Alarm loses database connection set forgotten"_test = [] { + ut::skip / "Alarm loses database connection set forgotten"_test = [] { test_setup_s* server = new test_setup_s(); test_setup_c client; info<"desc", "details"> i(client.connection, "dead_server_test"); diff --git a/libs/confman/inc/public/tfc/confman.hpp b/libs/confman/inc/public/tfc/confman.hpp index db403bb577..2d62ab918b 100644 --- a/libs/confman/inc/public/tfc/confman.hpp +++ b/libs/confman/inc/public/tfc/confman.hpp @@ -79,7 +79,9 @@ class config { auto operator->() const noexcept -> storage_t const* { return std::addressof(value()); } /// \return storage_t as json string - [[nodiscard]] auto string() const -> std::string { return glz::write_json(storage_.value()); } + [[nodiscard]] auto string() const -> std::expected { + return glz::write_json(storage_.value()); + } /// TODO can we do this differently, jsonforms requires object as root element /// an example of failure would be confman> as the json schema root element would be array @@ -95,11 +97,21 @@ class config { /// \return storage_t json schema [[nodiscard]] auto schema() const -> std::string { - return tfc::json::write_json_schema>(); + auto const value{ tfc::json::write_json_schema>() }; + if (!value.has_value()) { + logger_.error("Error writing json schema: {}", glz::format_error(value.error())); + return {}; + } + return value.value(); } auto set_changed() const noexcept -> std::error_code { - client_.set(this->string()); + auto value{ this->string() }; + if (!value.has_value()) { + logger_.error("Error writing string: {}", glz::format_error(value.error())); + return std::make_error_code(std::errc::io_error); + } + client_.set(std::move(value.value())); return storage_.set_changed(); } diff --git a/libs/confman/inc/public/tfc/confman/detail/config_dbus_client.hpp b/libs/confman/inc/public/tfc/confman/detail/config_dbus_client.hpp index 1e6161223e..8c1710233e 100644 --- a/libs/confman/inc/public/tfc/confman/detail/config_dbus_client.hpp +++ b/libs/confman/inc/public/tfc/confman/detail/config_dbus_client.hpp @@ -1,11 +1,14 @@ #pragma once +#include #include #include #include #include #include +#include + #include #include @@ -35,7 +38,7 @@ class config_dbus_client { /// \note Should only be used for testing !!! explicit config_dbus_client(dbus_connection_t); - using value_call_t = std::function; + using value_call_t = std::function()>; using schema_call_t = std::function; using change_call_t = std::function; /// \brief make dbus client using given dbus connection diff --git a/libs/confman/inc/public/tfc/confman/file_storage.hpp b/libs/confman/inc/public/tfc/confman/file_storage.hpp index 63bac2b3a9..35d1234b18 100644 --- a/libs/confman/inc/public/tfc/confman/file_storage.hpp +++ b/libs/confman/inc/public/tfc/confman/file_storage.hpp @@ -102,7 +102,11 @@ class file_storage { /// \brief generate json form of storage auto to_json() const noexcept -> std::string { std::string buffer{}; // this can throw, meaning memory error - glz::write(storage_, buffer); + auto const err{ glz::write(storage_, buffer) }; + if (err) { + logger_.error(R"(Error: "{}" writing to json)", glz::format_error(err)); + return {}; + } return buffer; } diff --git a/libs/confman/inc/public/tfc/confman/remote_change.hpp b/libs/confman/inc/public/tfc/confman/remote_change.hpp index 448a0592d6..e64caf38bd 100644 --- a/libs/confman/inc/public/tfc/confman/remote_change.hpp +++ b/libs/confman/inc/public/tfc/confman/remote_change.hpp @@ -20,8 +20,12 @@ template std::string_view key, config_storage_t&& storage, std::invocable auto&& handler) { - return set_config_impl(dbus, service, key, glz::write_json(std::forward(storage)), - std::forward(handler)); + auto const write{ glz::write_json(std::forward(storage)) }; + if (!write) { + handler(std::make_error_code(std::errc::invalid_argument)); + return; + } + return set_config_impl(dbus, service, key, write.value(), std::forward(handler)); } // todo get_config diff --git a/libs/confman/src/detail/config_dbus_client.cpp b/libs/confman/src/detail/config_dbus_client.cpp index c0abc956c1..d8000660e7 100644 --- a/libs/confman/src/detail/config_dbus_client.cpp +++ b/libs/confman/src/detail/config_dbus_client.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -55,7 +56,11 @@ void config_dbus_client::initialize() { return 1; }, [this]([[maybe_unused]] std::string const& value) -> std::string { // getter - return this->value_call_(); + const auto val{ this->value_call_() }; + if (!val) { + throw tfc::dbus::exception::runtime{ fmt::format("Unable to get value: '{}'", glz::format_error(val.error())) }; + } + return val.value(); }); dbus_interface_->register_property_r( schema_property_name_, sdbusplus::vtable::property_::const_, diff --git a/libs/confman/testing/examples/CMakeLists.txt b/libs/confman/testing/examples/CMakeLists.txt index c2e12e5465..20c7f73773 100644 --- a/libs/confman/testing/examples/CMakeLists.txt +++ b/libs/confman/testing/examples/CMakeLists.txt @@ -2,19 +2,19 @@ find_package(mp-units CONFIG REQUIRED) find_package(fmt CONFIG REQUIRED) add_executable(confman_example_simple simple.cpp) -target_link_libraries(confman_example_simple PRIVATE tfc::confman mp-units::si fmt::fmt) +target_link_libraries(confman_example_simple PRIVATE tfc::confman mp-units::systems fmt::fmt) add_executable(confman_example_nested nested.cpp) -target_link_libraries(confman_example_nested PRIVATE tfc::confman mp-units::si fmt::fmt) +target_link_libraries(confman_example_nested PRIVATE tfc::confman mp-units::systems fmt::fmt) add_executable(confman_example_array array.cpp) -target_link_libraries(confman_example_array PRIVATE tfc::confman mp-units::si fmt::fmt) +target_link_libraries(confman_example_array PRIVATE tfc::confman mp-units::systems fmt::fmt) add_executable(confman_example_variant variant.cpp) -target_link_libraries(confman_example_variant PRIVATE tfc::confman mp-units::si fmt::fmt) +target_link_libraries(confman_example_variant PRIVATE tfc::confman mp-units::systems fmt::fmt) add_executable(confman_example_item confman_example_item.cpp) -target_link_libraries(confman_example_item PRIVATE tfc::confman tfc::ipc mp-units::si fmt::fmt) +target_link_libraries(confman_example_item PRIVATE tfc::confman tfc::ipc mp-units::systems fmt::fmt) add_executable(confman_example_useless useless-box.cpp) -target_link_libraries(confman_example_useless PRIVATE tfc::confman tfc::ipc mp-units::si fmt::fmt) +target_link_libraries(confman_example_useless PRIVATE tfc::confman tfc::ipc mp-units::systems fmt::fmt) diff --git a/libs/confman/testing/examples/array.cpp b/libs/confman/testing/examples/array.cpp index 10f6cdf63a..d8da02907f 100644 --- a/libs/confman/testing/examples/array.cpp +++ b/libs/confman/testing/examples/array.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include @@ -32,7 +32,7 @@ int main(int argc, char** argv) { tfc::confman::config> const config{ dbus, "key" }; fmt::print("Schema is: {}\n", config.schema()); - fmt::print("Config is: {}\n", config.string()); + fmt::print("Config is: {}\n", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/examples/confman_example_item.cpp b/libs/confman/testing/examples/confman_example_item.cpp index c05517bc1f..fa7d0373f8 100644 --- a/libs/confman/testing/examples/confman_example_item.cpp +++ b/libs/confman/testing/examples/confman_example_item.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include @@ -22,11 +22,11 @@ int main(int argc, char** argv) { tfc::confman::config> const config{ dbus, "key" }; config->observe([](auto new_value, auto old_value) { - fmt::print("new value: {}, old value: {}\n", new_value.to_json(), old_value.to_json()); + fmt::print("new value: {}, old value: {}\n", new_value.to_json().value(), old_value.to_json().value()); }); fmt::print("Schema is: {}\n", config.schema()); - fmt::print("Config is: {}\n", config.string()); + fmt::print("Config is: {}\n", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/examples/nested.cpp b/libs/confman/testing/examples/nested.cpp index ae408ae5d4..e13e7b4fb0 100644 --- a/libs/confman/testing/examples/nested.cpp +++ b/libs/confman/testing/examples/nested.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -75,7 +75,7 @@ int main(int argc, char** argv) { [](auto new_value, auto old_value) { fmt::print("new value: {}, old value: {}\n", new_value, old_value); }); fmt::print("Schema is: {}\n", config.schema()); - fmt::print("Config is: {}\n", config.string()); + fmt::print("Config is: {}\n", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/examples/simple.cpp b/libs/confman/testing/examples/simple.cpp index 979e21de60..d5551d075a 100644 --- a/libs/confman/testing/examples/simple.cpp +++ b/libs/confman/testing/examples/simple.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include @@ -57,7 +57,7 @@ int main(int argc, char** argv) { [](bool new_value, bool old_value) { fmt::print("new value: {}, old value: {}\n", new_value, old_value); }); fmt::print("Schema is: {}\n", config.schema()); - fmt::print("Config is: {}\n", config.string()); + fmt::print("Config is: {}\n", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/examples/useless-box.cpp b/libs/confman/testing/examples/useless-box.cpp index 3e41c99984..ad94c6d76b 100644 --- a/libs/confman/testing/examples/useless-box.cpp +++ b/libs/confman/testing/examples/useless-box.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -81,7 +81,7 @@ int main(int argc, char** argv) { }); fmt::println("Schema is: {}", config.schema()); - fmt::println("Config is: {}", config.string()); + fmt::println("Config is: {}", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/examples/variant.cpp b/libs/confman/testing/examples/variant.cpp index a0c38539c7..87f3ad7783 100644 --- a/libs/confman/testing/examples/variant.cpp +++ b/libs/confman/testing/examples/variant.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -65,11 +65,12 @@ int main(int argc, char** argv) { tfc::confman::config>> const config{ dbus, "key" }; config->observe([](auto const& new_value, auto const& old_value) { - fmt::print("new value:\n{}\n\n\nold value:\n{}\n", glz::write_json(new_value), glz::write_json(old_value)); + fmt::print("new value:\n{}\n\n\nold value:\n{}\n", glz::write_json(new_value).value(), + glz::write_json(old_value).value()); }); fmt::print("Schema is: {}\n", config.schema()); - fmt::print("Config is: {}\n", config.string()); + fmt::print("Config is: {}\n", config.string().value()); dbus->request_name(tfc::dbus::make_dbus_process_name().c_str()); diff --git a/libs/confman/testing/stubs/inc/public/tfc/stubs/confman.hpp b/libs/confman/testing/stubs/inc/public/tfc/stubs/confman.hpp index 2248f91cd1..9aa43fc0cc 100644 --- a/libs/confman/testing/stubs/inc/public/tfc/stubs/confman.hpp +++ b/libs/confman/testing/stubs/inc/public/tfc/stubs/confman.hpp @@ -37,8 +37,10 @@ class stub_config : public detail::stubbed_config { auto access() noexcept -> storage_t& { return storage_; } auto operator->() const noexcept -> storage_t const* { return std::addressof(value()); } - [[nodiscard]] auto string() const -> std::string { return glz::write_json(storage_); } - [[nodiscard]] auto schema() const -> std::string { return glz::write_json_schema(); } + [[nodiscard]] auto string() const -> std::expected { return glz::write_json(storage_); } + [[nodiscard]] auto schema() const -> std::expected { + return glz::write_json_schema(); + } auto set_changed() const noexcept -> std::error_code { return {}; } diff --git a/libs/confman/testing/tests/confman_test.cpp b/libs/confman/testing/tests/confman_test.cpp index 8c0539c9bd..e90fa925c5 100644 --- a/libs/confman/testing/tests/confman_test.cpp +++ b/libs/confman/testing/tests/confman_test.cpp @@ -77,7 +77,7 @@ auto main(int argc, char** argv) -> int { .a = observable{ 1 }, .b = observable{ 2 }, .c = observable{ "bar" } } }; auto const json_str{ test.config.string() }; glz::json_t json{}; - std::ignore = glz::read_json(json, json_str); + std::ignore = glz::read_json(json, json_str.value_or("")); ut::expect(static_cast(json["a"].get()) == 1); ut::expect(static_cast(json["b"].get()) == 2); ut::expect(json["c"].get() == "bar"); @@ -102,7 +102,7 @@ auto main(int argc, char** argv) -> int { json["a"] = 11; json["b"] = 22; json["c"] = "meeoow"; - test.config.from_string(glz::write_json(json)); + test.config.from_string(glz::write_json(json).value()); ut::expect(11 == test.config->a); ut::expect(22 == test.config->b); ut::expect("meeoow" == test.config->c); @@ -123,7 +123,7 @@ auto main(int argc, char** argv) -> int { json["a"] = 11; json["b"] = 22; json["c"] = "meeoow"; - test.config.from_string(glz::write_json(json)); + test.config.from_string(glz::write_json(json).value()); ut::expect(1 == c_called); }; diff --git a/libs/confman/testing/tests/observer_test.cpp b/libs/confman/testing/tests/observer_test.cpp index d186974415..c2b210f6db 100644 --- a/libs/confman/testing/tests/observer_test.cpp +++ b/libs/confman/testing/tests/observer_test.cpp @@ -13,8 +13,9 @@ auto main(int, char**) -> int { }; "glaze conversion test"_test = []() { tfc::confman::observable const observed_value(25); - std::string const value_as_str = glz::write_json(observed_value); - expect(value_as_str == "25"); + auto const value_as_str = glz::write_json(observed_value); + expect(fatal(value_as_str.has_value())); + expect(value_as_str.value() == "25"); }; "std containers"_test = []() { std::array array{ 1, 2, 3 }; diff --git a/libs/dbus_util/CMakeLists.txt b/libs/dbus_util/CMakeLists.txt index 0a1adf3d07..5f6b921b24 100644 --- a/libs/dbus_util/CMakeLists.txt +++ b/libs/dbus_util/CMakeLists.txt @@ -1,6 +1,6 @@ project(dbus_util) -add_library(dbus_util src/dbus_util.cpp src/compile_tests.cpp src/exception.cpp src/sml_interface.cpp src/string_maker.cpp) +add_library(dbus_util src/dbus_util.cpp src/exception.cpp src/sml_interface.cpp src/string_maker.cpp) add_library(tfc::dbus_util ALIAS dbus_util) target_include_directories(dbus_util PUBLIC diff --git a/libs/dbus_util/inc/public/tfc/dbus/sdbusplus_meta.hpp b/libs/dbus_util/inc/public/tfc/dbus/sdbusplus_meta.hpp index f656ab67a0..d4c3850511 100644 --- a/libs/dbus_util/inc/public/tfc/dbus/sdbusplus_meta.hpp +++ b/libs/dbus_util/inc/public/tfc/dbus/sdbusplus_meta.hpp @@ -86,7 +86,7 @@ struct append_single { template struct append_single { - static void op(auto* interface, auto* sd_bus_msg, mp_units::Quantity auto&& item) { + static void op(auto* interface, auto* sd_bus_msg, quantity_t const& item) { using value_t = typename quantity_t::rep; append_single::op(interface, sd_bus_msg, item.numerical_value_ref_in(quantity_t::unit)); } diff --git a/libs/dbus_util/tests/CMakeLists.txt b/libs/dbus_util/tests/CMakeLists.txt index 32720ea9c9..c145eec905 100644 --- a/libs/dbus_util/tests/CMakeLists.txt +++ b/libs/dbus_util/tests/CMakeLists.txt @@ -7,6 +7,12 @@ target_link_libraries(dbus_util_test PRIVATE Boost::ut tfc::logger tfc::base tfc add_test(NAME dbus_util_test COMMAND dbus_util_test) +add_executable(dbus_compile_tests dbus_compile_tests.cpp) + +target_link_libraries(dbus_compile_tests PRIVATE Boost::ut tfc::logger tfc::base tfc::dbus_util) + +add_test(NAME dbus_compile_tests COMMAND dbus_compile_tests) + if(BUILD_EXAMPLES) -add_subdirectory(examples) + add_subdirectory(examples) endif() diff --git a/libs/dbus_util/src/compile_tests.cpp b/libs/dbus_util/tests/dbus_compile_tests.cpp similarity index 94% rename from libs/dbus_util/src/compile_tests.cpp rename to libs/dbus_util/tests/dbus_compile_tests.cpp index 3cbf6c265d..e09a931d16 100644 --- a/libs/dbus_util/src/compile_tests.cpp +++ b/libs/dbus_util/tests/dbus_compile_tests.cpp @@ -12,3 +12,7 @@ static_assert("path_namespace='foo'," == path_namespace); static_assert("destination='foo'," == destination); } // namespace tfc::dbus::match::rules::test + +int main() { + return 0; +} diff --git a/libs/dbus_util/tests/dbus_util_test.cpp b/libs/dbus_util/tests/dbus_util_test.cpp index b03b7f7521..90fd0ae422 100644 --- a/libs/dbus_util/tests/dbus_util_test.cpp +++ b/libs/dbus_util/tests/dbus_util_test.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include #include #include diff --git a/libs/ipc/inc/public/tfc/ipc/details/dbus_ipc.hpp b/libs/ipc/inc/public/tfc/ipc/details/dbus_ipc.hpp index af2448ad83..4a82745957 100644 --- a/libs/ipc/inc/public/tfc/ipc/details/dbus_ipc.hpp +++ b/libs/ipc/inc/public/tfc/ipc/details/dbus_ipc.hpp @@ -47,9 +47,7 @@ class dbus_ipc { interface_->register_signal(std::string{ dbus::tags::value }); interface_->register_property_r(std::string{ dbus::tags::value }, sdbusplus::vtable::property_::none, [this](const auto&) { return value_; }); - interface_->register_property_r( - std::string{ dbus::tags::type }, sdbusplus::vtable::property_::const_, - []([[maybe_unused]] std::string& old_value) { return tfc::json::write_json_schema(); }); + interface_->register_property(std::string{ dbus::tags::type }, schema); interface_->initialize(); } @@ -70,6 +68,7 @@ class dbus_ipc { private: std::shared_ptr interface_{}; value_t value_{}; + std::string const schema{ tfc::json::write_json_schema().value() }; }; } // namespace tfc::ipc::details diff --git a/libs/ipc/inc/public/tfc/ipc/details/dbus_server_iface.hpp b/libs/ipc/inc/public/tfc/ipc/details/dbus_server_iface.hpp index ccb605ae4a..f78e91a43b 100644 --- a/libs/ipc/inc/public/tfc/ipc/details/dbus_server_iface.hpp +++ b/libs/ipc/inc/public/tfc/ipc/details/dbus_server_iface.hpp @@ -282,16 +282,34 @@ class ipc_manager_server { }); dbus_interface_->register_property_r( - std::string(consts::signals_property), sdbusplus::vtable::property_::emits_change, - [&](const auto&) { return glz::write_json(ipc_manager_->get_all_signals()); }); + std::string(consts::signals_property), sdbusplus::vtable::property_::emits_change, [&](const auto&) { + auto const write{ glz::write_json(ipc_manager_->get_all_signals()) }; + if (!write) { + fmt::println(stderr, "Failed to write signals to json: {}", format_error(write.error())); + throw dbus_error("Failed to write signals to json"); + } + return write.value(); + }); dbus_interface_->register_property_r( - std::string(consts::slots_property), sdbusplus::vtable::property_::emits_change, - [&](const auto&) { return glz::write_json(ipc_manager_->get_all_slots()); }); + std::string(consts::slots_property), sdbusplus::vtable::property_::emits_change, [&](const auto&) { + auto const write{ glz::write_json(ipc_manager_->get_all_slots()) }; + if (!write) { + fmt::println(stderr, "Failed to write slots to json: {}", format_error(write.error())); + throw dbus_error("Failed to write slots to json"); + } + return write.value(); + }); dbus_interface_->register_property_r( - std::string(consts::connections_property), sdbusplus::vtable::property_::emits_change, - [&](const auto&) { return glz::write_json(ipc_manager_->get_all_connections()); }); + std::string(consts::connections_property), sdbusplus::vtable::property_::emits_change, [&](const auto&) { + auto const write{ glz::write_json(ipc_manager_->get_all_connections()) }; + if (!write) { + fmt::println(stderr, "Failed to write connections to json: {}", format_error(write.error())); + throw dbus_error("Failed to write connections to json"); + } + return write.value(); + }); dbus_interface_->register_signal>(""); dbus_interface_->initialize(); diff --git a/libs/ipc/inc/public/tfc/ipc/details/type_description.hpp b/libs/ipc/inc/public/tfc/ipc/details/type_description.hpp index 1d4fb2c630..ce5f15fced 100644 --- a/libs/ipc/inc/public/tfc/ipc/details/type_description.hpp +++ b/libs/ipc/inc/public/tfc/ipc/details/type_description.hpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include @@ -40,7 +40,7 @@ using type_length = type_description; using pressure_t = std::expected, std::int64_t>, sensor_error_e>; using type_pressure = type_description; -inline constexpr struct celsius : mp_units::named_unit { +inline constexpr struct celsius final : mp_units::named_unit { } celsius; using temperature_t = std::expected, std::int64_t>, sensor_error_e>; using type_temperature = type_description; diff --git a/libs/ipc/inc/public/tfc/ipc/item.hpp b/libs/ipc/inc/public/tfc/ipc/item.hpp index 82facd8fb4..4eabf6cf13 100644 --- a/libs/ipc/inc/public/tfc/ipc/item.hpp +++ b/libs/ipc/inc/public/tfc/ipc/item.hpp @@ -9,9 +9,10 @@ #include #include -#include -#include +#include +#include #include +#include #include @@ -209,8 +210,8 @@ struct item; struct item { using time_point = std::chrono::time_point; - [[nodiscard]] static auto from_json(std::string_view json) -> std::expected; - [[nodiscard]] auto to_json() const -> std::string; + [[nodiscard]] static auto from_json(std::string_view json) -> std::expected; + [[nodiscard]] auto to_json() const -> std::expected; [[nodiscard]] auto id() const -> std::string; // ids diff --git a/libs/ipc/src/item.cpp b/libs/ipc/src/item.cpp index bf20b7f63e..7a5a4bd980 100644 --- a/libs/ipc/src/item.cpp +++ b/libs/ipc/src/item.cpp @@ -27,7 +27,7 @@ auto make() -> item { pcg_extras::seed_seq_from seed_source; return make(seed_source); } -auto item::from_json(std::string_view json) -> std::expected { +auto item::from_json(std::string_view json) -> std::expected { auto temporary = glz::read_json(json); if (!temporary.has_value()) { return temporary; @@ -36,7 +36,7 @@ auto item::from_json(std::string_view json) -> std::expectedlast_exchange = std::chrono::time_point_cast(std::chrono::system_clock::now()); return temporary; } -auto item::to_json() const -> std::string { +auto item::to_json() const -> std::expected { return glz::write_json(*this); } auto item::id() const -> std::string { diff --git a/libs/ipc/testing/examples/CMakeLists.txt b/libs/ipc/testing/examples/CMakeLists.txt index 0d22f7cf56..a454bee7e2 100644 --- a/libs/ipc/testing/examples/CMakeLists.txt +++ b/libs/ipc/testing/examples/CMakeLists.txt @@ -10,4 +10,4 @@ target_link_libraries(ipc_coroutines PRIVATE tfc::base tfc::ipc) find_package(mp-units CONFIG REQUIRED) find_package(fmt CONFIG REQUIRED) tfc_add_example_no_test(mass_example mass_example.cpp) -target_link_libraries(mass_example PRIVATE tfc::base tfc::ipc mp-units::si fmt::fmt) +target_link_libraries(mass_example PRIVATE tfc::base tfc::ipc mp-units::systems fmt::fmt) diff --git a/libs/ipc/testing/examples/mass_example.cpp b/libs/ipc/testing/examples/mass_example.cpp index 84cd435dba..0d96723d93 100644 --- a/libs/ipc/testing/examples/mass_example.cpp +++ b/libs/ipc/testing/examples/mass_example.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include namespace asio = boost::asio; diff --git a/libs/ipc/testing/tests/filter_test.cpp b/libs/ipc/testing/tests/filter_test.cpp index f95d018e4a..330d0daef2 100644 --- a/libs/ipc/testing/tests/filter_test.cpp +++ b/libs/ipc/testing/tests/filter_test.cpp @@ -163,11 +163,11 @@ auto main(int, char**) -> int { // down to business, 1 second has elapsed since the timer was made with 42 second wait, let's change the value expect(!finished); - auto generic_config = glz::read_json(config.string()); + auto generic_config = glz::read_json(config.string().value_or("")); expect(generic_config.has_value() >> fatal); generic_config->at("time_on") = 1; // reduce time to only 1 millisecond generic_config->at("time_off") = 1; - expect(!config.from_string(glz::write_json(generic_config.value())) >> fatal); + expect(!config.from_string(glz::write_json(generic_config.value()).value_or("")) >> fatal); expect(config->time_on == 1ms); expect(config->time_off == 1ms); diff --git a/libs/ipc/testing/tests/item_test.cpp b/libs/ipc/testing/tests/item_test.cpp index fa16970b49..0f10a9930d 100644 --- a/libs/ipc/testing/tests/item_test.cpp +++ b/libs/ipc/testing/tests/item_test.cpp @@ -38,7 +38,7 @@ auto main(int, char**) -> int { "json"_test = [] { auto item = item::make(); - auto remake = item::item::from_json(item.to_json()).value(); + auto remake = item::item::from_json(item.to_json().value_or("")).value(); // last_exchange is updated when from_json is called expect(item.id() == remake.id()); expect(item.entry_timestamp == remake.entry_timestamp); diff --git a/libs/motor/inc/public/tfc/motor.hpp b/libs/motor/inc/public/tfc/motor.hpp index 55df6bfe2d..2bfa9114fd 100644 --- a/libs/motor/inc/public/tfc/motor.hpp +++ b/libs/motor/inc/public/tfc/motor.hpp @@ -3,8 +3,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/atv320motor.hpp b/libs/motor/inc/public/tfc/motor/atv320motor.hpp index 4bc1b4084d..8b78a13e60 100644 --- a/libs/motor/inc/public/tfc/motor/atv320motor.hpp +++ b/libs/motor/inc/public/tfc/motor/atv320motor.hpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/dbus_tags.hpp b/libs/motor/inc/public/tfc/motor/dbus_tags.hpp index 47e30ad069..2a96451c64 100644 --- a/libs/motor/inc/public/tfc/motor/dbus_tags.hpp +++ b/libs/motor/inc/public/tfc/motor/dbus_tags.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/details/positioner_impl.hpp b/libs/motor/inc/public/tfc/motor/details/positioner_impl.hpp index 4284667f5c..ca8d7321f2 100644 --- a/libs/motor/inc/public/tfc/motor/details/positioner_impl.hpp +++ b/libs/motor/inc/public/tfc/motor/details/positioner_impl.hpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/positioner.hpp b/libs/motor/inc/public/tfc/motor/positioner.hpp index 0e018a1688..8742aa2df2 100644 --- a/libs/motor/inc/public/tfc/motor/positioner.hpp +++ b/libs/motor/inc/public/tfc/motor/positioner.hpp @@ -15,10 +15,10 @@ #include #include -#include #include -#include -#include +#include +#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/stub.hpp b/libs/motor/inc/public/tfc/motor/stub.hpp index 7e15f771aa..814a9d8666 100644 --- a/libs/motor/inc/public/tfc/motor/stub.hpp +++ b/libs/motor/inc/public/tfc/motor/stub.hpp @@ -3,8 +3,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/libs/motor/inc/public/tfc/motor/virtual_motor.hpp b/libs/motor/inc/public/tfc/motor/virtual_motor.hpp index 5b9542959d..a78d89b6f2 100644 --- a/libs/motor/inc/public/tfc/motor/virtual_motor.hpp +++ b/libs/motor/inc/public/tfc/motor/virtual_motor.hpp @@ -1,12 +1,12 @@ #pragma once #include -#include #include #include #include -#include -#include +#include +#include +#include #include #include diff --git a/libs/motor/testing/mocks/inc/public/tfc/mocks/motor.hpp b/libs/motor/testing/mocks/inc/public/tfc/mocks/motor.hpp index d43f28caaa..a52c0fe690 100644 --- a/libs/motor/testing/mocks/inc/public/tfc/mocks/motor.hpp +++ b/libs/motor/testing/mocks/inc/public/tfc/mocks/motor.hpp @@ -5,8 +5,8 @@ #include #include -#include -#include +#include +#include #include diff --git a/libs/motor/testing/tests/motor_impl_test.cxx b/libs/motor/testing/tests/motor_impl_test.cxx index 83796b0abf..38c0fbee7a 100644 --- a/libs/motor/testing/tests/motor_impl_test.cxx +++ b/libs/motor/testing/tests/motor_impl_test.cxx @@ -1,6 +1,6 @@ #include -#include -#include +#include +#include #include #include diff --git a/libs/motor/testing/tests/stub_test.cpp b/libs/motor/testing/tests/stub_test.cpp index 02f59f0434..6963804a66 100644 --- a/libs/motor/testing/tests/stub_test.cpp +++ b/libs/motor/testing/tests/stub_test.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/libs/snitch/inc/public/tfc/snitch/format_extension.hpp b/libs/snitch/inc/public/tfc/snitch/format_extension.hpp index df860e575e..0a5e61a16c 100644 --- a/libs/snitch/inc/public/tfc/snitch/format_extension.hpp +++ b/libs/snitch/inc/public/tfc/snitch/format_extension.hpp @@ -24,21 +24,22 @@ struct custom_handler { return 0; } - constexpr auto on_arg_id(fmt::basic_string_view id) -> int { + constexpr auto on_arg_id([[maybe_unused]] fmt::basic_string_view id) -> int { num_args++; return 0; } - constexpr void on_replacement_field(int id, const Char* begin) { + constexpr void on_replacement_field([[maybe_unused]] int id, [[maybe_unused]] const Char* begin) { // todo } - constexpr auto on_format_specs(int id, const Char* begin, const Char*) -> const Char* { + constexpr auto on_format_specs([[maybe_unused]] int id, [[maybe_unused]] const Char* begin, [[maybe_unused]] const Char*) + -> const Char* { has_format_specs = true; return begin; } - constexpr void on_error(const char* msg) { + constexpr void on_error([[maybe_unused]] const char* msg) { // todo } std::size_t num_args{ 0 }; @@ -82,13 +83,13 @@ struct custom_handler_names { return 0; } - constexpr void on_replacement_field(int id, const Char* begin) { + constexpr void on_replacement_field([[maybe_unused]] int id, [[maybe_unused]] const Char* begin) { // todo } - constexpr auto on_format_specs(int id, const Char* begin, const Char*) -> const Char* { return begin; } + constexpr auto on_format_specs([[maybe_unused]] int id, const Char* begin, const Char*) -> const Char* { return begin; } - constexpr void on_error(const char* msg) { + constexpr void on_error([[maybe_unused]] const char* msg) { // todo } std::array, N> names{}; diff --git a/libs/snitch/tests/snitch_test.cpp b/libs/snitch/tests/snitch_test.cpp index f34eea4d4e..f984eac661 100644 --- a/libs/snitch/tests/snitch_test.cpp +++ b/libs/snitch/tests/snitch_test.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -34,14 +35,18 @@ static_assert(arg_names>()[0] == "name"); static_assert(arg_names>()[0] == "name"); static_assert(arg_names>().size() == 1); +#if __clang__ // todo this below shouldn't be possible, remove if causes compile error and celebrate static_assert(arg_names>()[1] == ""); +#endif static_assert(arg_names>()[1] == "bar"); static_assert(arg_names>().size() == 2); } // namespace test +static_assert(glz::detail::count_members == 10); + auto main(int argc, char** argv) -> int { using boost::ut::operator""_test; using boost::ut::expect; diff --git a/libs/stx/CMakeLists.txt b/libs/stx/CMakeLists.txt index b13d871e9a..28d2ed7c58 100644 --- a/libs/stx/CMakeLists.txt +++ b/libs/stx/CMakeLists.txt @@ -16,7 +16,7 @@ find_package(stduuid CONFIG REQUIRED) target_link_libraries(stx INTERFACE glaze::glaze - mp-units::si + mp-units::systems mp-units::core stduuid ) diff --git a/libs/stx/inc/public/tfc/stx/concepts.hpp b/libs/stx/inc/public/tfc/stx/concepts.hpp index 4c04bb671b..8b4e2c38c8 100644 --- a/libs/stx/inc/public/tfc/stx/concepts.hpp +++ b/libs/stx/inc/public/tfc/stx/concepts.hpp @@ -7,7 +7,7 @@ #include // todo should we split to concepts/mp-units.hpp and concepts/stx.hpp -#include +#include namespace tfc::stx { diff --git a/libs/stx/inc/public/tfc/stx/to_string_view.hpp b/libs/stx/inc/public/tfc/stx/to_string_view.hpp index 934b8fab0c..91d087d530 100644 --- a/libs/stx/inc/public/tfc/stx/to_string_view.hpp +++ b/libs/stx/inc/public/tfc/stx/to_string_view.hpp @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include #include #include diff --git a/libs/stx/inc/public/tfc/utils/units_glaze_meta.hpp b/libs/stx/inc/public/tfc/utils/units_glaze_meta.hpp index 27568e197e..7c44fed327 100644 --- a/libs/stx/inc/public/tfc/utils/units_glaze_meta.hpp +++ b/libs/stx/inc/public/tfc/utils/units_glaze_meta.hpp @@ -5,10 +5,10 @@ #include #include -#include -#include +#include +#include +#include #include -#include #include #include @@ -19,8 +19,10 @@ consteval auto dimension_name() -> std::string_view; } // namespace tfc::unit template <> -struct glz::meta { - static constexpr auto value{ glz::object("numerator", &mp_units::ratio::num, "denominator", &mp_units::ratio::den) }; +struct glz::meta { + static constexpr auto value{ + glz::object("numerator", &mp_units::detail::ratio::num, "denominator", &mp_units::detail::ratio::den) + }; static constexpr auto name{ "units::ratio" }; }; @@ -85,7 +87,7 @@ struct to_json_schema> { static_assert(unit_ascii[0] != 0); static_assert(unit_unicode[0] != 0); - static constexpr mp_units::ratio ratio{ mp_units::as_ratio(ref_t) }; + static constexpr mp_units::detail::ratio ratio{ mp_units::detail::as_ratio(ref_t) }; static constexpr auto dimension{ tfc::unit::dimension_name() }; template static void op([[maybe_unused]] auto& schema, [[maybe_unused]] auto& defs) { @@ -99,7 +101,7 @@ struct to_json_schema> { // if constexpr (mp_units::Magnitude) { // data->ratio = tfc::json::schema_meta::ratio_impl{ .numerator = ratio.num, .denominator = ratio.den }; // } - // to_json_schema::template op(schema, defs); + to_json_schema::template op(schema, defs); } }; diff --git a/libs/stx/tests/CMakeLists.txt b/libs/stx/tests/CMakeLists.txt index 5cc58f65ba..3d9c3697bf 100644 --- a/libs/stx/tests/CMakeLists.txt +++ b/libs/stx/tests/CMakeLists.txt @@ -11,8 +11,7 @@ target_link_libraries(test_glaze_meta Boost::ut glaze::glaze fmt::fmt - mp-units::si - mp-units::isq + mp-units::systems mp-units::core ) diff --git a/libs/stx/tests/test_glaze_meta.cpp b/libs/stx/tests/test_glaze_meta.cpp index 6455fd1a27..129e869888 100644 --- a/libs/stx/tests/test_glaze_meta.cpp +++ b/libs/stx/tests/test_glaze_meta.cpp @@ -1,7 +1,7 @@ #include -#include -#include -#include +#include +#include +#include #include #include @@ -42,25 +42,28 @@ auto main() -> int { "chrono"_test = [] { using test_t = std::chrono::duration; test_t foo{ std::chrono::seconds(32) }; - std::string const json{ glz::write_json(foo) }; - ut::expect(json == "320") << "got: " << json; - ut::expect(glz::read_json(json).value() == foo); + auto const json{ glz::write_json(foo) }; + ut::expect(fatal(json.has_value())); + ut::expect(json == "320") << "got: " << json.value(); + ut::expect(glz::read_json(json.value()).value() == foo); }; "mp"_test = [] { using namespace mp_units::si::unit_symbols; auto foo{ 42 * (km / h) }; - std::string const json{ glz::write_json(foo) }; - ut::expect(json == "42") << "got: " << json; - [[maybe_unused]] auto bar = glz::read_json(json); + auto const json{ glz::write_json(foo) }; + ut::expect(fatal(json.has_value())); + ut::expect(json.value() == "42") << "got: " << json.value(); + [[maybe_unused]] auto bar = glz::read_json(json.value()); if (!bar.has_value()) { - fmt::print("{}\n", glz::format_error(bar.error(), json)); + fmt::print("{}\n", glz::format_error(bar.error(), json.value())); } - ut::expect(glz::read_json(json).has_value()); + ut::expect(glz::read_json(json.value()).has_value()); }; "fixed_string_to_json"_test = [] { tfc::stx::basic_fixed_string foo{ "HelloWorld" }; auto foo_json{ glz::write_json(foo) }; - ut::expect(foo_json == R"("HelloWorld")") << glz::write_json(foo); + ut::expect(fatal(foo_json.has_value())); + ut::expect(foo_json.value() == R"("HelloWorld")") << glz::write_json(foo).value(); }; "fixed_string_from_json"_test = [] { auto foo = glz::read_json>("\"Hello\""); @@ -82,7 +85,8 @@ auto main() -> int { "millisecond clock"_test = [] { auto now{ std::chrono::time_point_cast(std::chrono::system_clock::now()) }; auto json{ glz::write_json(now) }; - ut::expect(glz::read_json(json).value() == now); + ut::expect(fatal(json.has_value())); + ut::expect(glz::read_json(json.value()).value() == now); }; return EXIT_SUCCESS; } diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json index 34f2f3a412..5bf4d929cb 100644 --- a/vcpkg-configuration.json +++ b/vcpkg-configuration.json @@ -1,8 +1,8 @@ { "default-registry": { "kind": "git", - "repository": "https://github.com/microsoft/vcpkg", - "baseline": "37630acd98740f98d7e51da78c9630758835037f" + "repository": "https://github.com/jbbjarnason/vcpkg", + "baseline": "af4022b2a83149c292dd74b9b3ebc8c0f90ee595" }, "registries": [ { diff --git a/vcpkg.json b/vcpkg.json index 97e0ab2d40..f1292d0f34 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -14,6 +14,7 @@ "boost-coroutine2", "boost-stacktrace", "boost-uuid", + "boost-hana", "date", "fmt", {