diff --git a/example/client/flex/main.cpp b/example/client/flex/main.cpp index 97f155d..8dbbfb1 100644 --- a/example/client/flex/main.cpp +++ b/example/client/flex/main.cpp @@ -9,20 +9,23 @@ #include #include +#include #include #include #include #include #include -#include #include +#include #if defined(BOOST_ASIO_HAS_CO_AWAIT) #include +#include namespace asio = boost::asio; +namespace buffers = boost::buffers; namespace core = boost::core; namespace http_io = boost::http_io; namespace http_proto = boost::http_proto; @@ -36,6 +39,90 @@ inline const bool http_proto_has_zlib = true; inline const bool http_proto_has_zlib = false; #endif +core::string_view +mime_type(core::string_view path) noexcept +{ + const auto ext = [&path] + { + const auto pos = path.rfind("."); + if(pos == core::string_view::npos) + return core::string_view{}; + return path.substr(pos); + }(); + + using ci_equal = urls::grammar::ci_equal; + if(ci_equal{}(ext, ".gif")) return "image/gif"; + if(ci_equal{}(ext, ".jpg")) return "image/jpeg"; + if(ci_equal{}(ext, ".jpeg")) return "image/jpeg"; + if(ci_equal{}(ext, ".png")) return "image/png"; + if(ci_equal{}(ext, ".svg")) return "image/svg+xml"; + if(ci_equal{}(ext, ".txt")) return "text/plain"; + if(ci_equal{}(ext, ".htm")) return "text/html"; + if(ci_equal{}(ext, ".html")) return "text/html"; + if(ci_equal{}(ext, ".pdf")) return "application/pdf"; + if(ci_equal{}(ext, ".xml")) return "application/xml"; + return "application/octet-stream"; +} + +core::string_view +filename(core::string_view path) noexcept +{ + const auto pos = path.find_last_of("/\\"); + if((pos != std::string_view::npos)) + return path.substr(pos + 1); + return path; +} + +std::uint64_t +filesize(core::string_view path) +{ + http_proto::file file; + boost::system::error_code ec; + file.open( + std::string{ path }.c_str(), + http_proto::file_mode::scan, + ec); + auto size = file.size(ec); + if(ec) + throw boost::system::system_error{ ec }; + return size; +} + +core::string_view +target(urls::url_view url) noexcept +{ + if(url.encoded_target().empty()) + return "/"; + + return url.encoded_target(); +} + +struct is_redirect_result +{ + bool is_redirect; + bool need_method_change; +}; + +is_redirect_result +is_redirect(http_proto::status status) noexcept +{ + // The specifications do not intend for 301 and 302 + // redirects to change the HTTP method, but most + // user agents do change the method in practice. + switch(status) + { + case http_proto::status::moved_permanently: + case http_proto::status::found: + case http_proto::status::see_other: + return { true, true }; + case http_proto::status::temporary_redirect: + case http_proto::status::permanent_redirect: + return { true, false }; + default: + return { false, false }; + } +} + class any_stream { public: @@ -163,31 +250,428 @@ class any_stream class output_stream { - std::ofstream file_; + http_proto::file file_; public: output_stream() = default; explicit output_stream(core::string_view path) { - file_.exceptions(std::ofstream::badbit); - file_.open(path, std::ios::binary); - if(!file_.is_open()) - throw std::runtime_error{ "Couldn't open the output file" }; + boost::system::error_code ec; + file_.open( + std::string{ path }.c_str(), http_proto::file_mode::write, ec); + if(ec) + throw boost::system::system_error{ ec }; } void - write(auto buf) + write(core::string_view str) { if(file_.is_open()) { - file_.write(static_cast(buf.data()), buf.size()); + boost::system::error_code ec; + file_.write(str.data(), str.size(), ec); + if(ec) + throw boost::system::system_error{ ec }; return; } - std::cout.write(static_cast(buf.data()), buf.size()); + std::cout.write(str.data(), str.size()); + } +}; + +class urlencoded_form +{ + std::string body_; + +public: + class source; + void + append_text( + core::string_view name, + core::string_view value) noexcept + { + if(!body_.empty()) + body_ += '&'; + body_ += name; + if(!value.empty()) + body_ += '='; + append_encoded(value); + } + + void + append_file(core::string_view path) + { + http_proto::file file; + boost::system::error_code ec; + + file.open( + std::string{ path }.c_str(), http_proto::file_mode::read, ec); + if(ec) + throw boost::system::system_error{ ec }; + + if(!body_.empty()) + body_ += '&'; + + for(;;) + { + char buf[64 * 1024]; + auto read = file.read(buf, sizeof(buf), ec); + if(ec) + throw boost::system::system_error{ ec }; + if(read == 0) + break; + append_encoded({ buf, read }); + } + } + + core::string_view + content_type() const noexcept + { + return "application/x-www-form-urlencoded"; + } + + std::size_t + content_length() const noexcept + { + return body_.size(); + } + + buffers::const_buffer + body() const noexcept + { + return { body_.data(), body_.size() }; + } + +private: + void + append_encoded(core::string_view str) + { + urls::encoding_opts opt; + opt.space_as_plus = true; + urls::encode( + str, urls::pchars, opt, urls::string_token::append_to(body_)); + } +}; + +class multipart_form +{ + struct part_t + { + core::string_view name; + core::string_view value_or_path; + core::string_view content_type; + std::optional file_size; + }; + + // storage_ containts boundary with extra "--" prefix and postfix. + // This reduces the number of steps needed during serialization. + std::array storage_{ generate_boundary() }; + std::vector parts_; + + static constexpr core::string_view content_disposition_ = + "\r\nContent-Disposition: form-data; name=\""; + static constexpr core::string_view filename_ = + "; filename=\""; + static constexpr core::string_view content_type_ = + "\r\nContent-Type: "; + +public: + class source; + + void + append_text( + core::string_view name, + core::string_view value, + core::string_view content_type) + { + parts_.emplace_back(name, value, content_type ); + } + + void + append_file( + core::string_view name, + core::string_view path, + core::string_view content_type) + { + // store size because file may change on disk between call to + // content_length and serialization. + parts_.emplace_back( + name, path, content_type, filesize(path)); + } + + std::string + content_type() const noexcept + { + std::string res = "multipart/form-data; boundary="; + res.append(storage_.begin() + 2, storage_.end() - 2); // boundary + return res; + } + + std::uint64_t + content_length() const noexcept + { + auto rs = std::uint64_t{}; + for(const auto& part : parts_) + { + rs += storage_.size() - 2; // --boundary + rs += content_disposition_.size(); + rs += part.name.size(); + rs += 1; // Closing double quote + + if(!part.content_type.empty()) + { + rs += content_type_.size(); + rs += part.content_type.size(); + } + + if(part.file_size.has_value()) // file + { + rs += filename_.size(); + rs += filename(part.value_or_path).size(); + rs += 1; // Closing double quote + rs += part.file_size.value(); + } + else // text + { + rs += part.value_or_path.size(); + } + + rs += 4; // after header + rs += 2; // after content + } + rs += storage_.size(); // --boundary-- + return rs; + } + +private: + static + decltype(storage_) + generate_boundary() + { + decltype(storage_) rs; + constexpr static char chars[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + static std::random_device rd; + std::uniform_int_distribution dist{ 0, sizeof(chars) - 2 }; + std::fill(rs.begin(), rs.end(), '-'); + std::generate( + rs.begin() + 2 + 24, + rs.end() - 2, + [&] { return chars[dist(rd)]; }); + return rs; } }; +class multipart_form::source + : public http_proto::source +{ + const multipart_form* form_; + std::vector::const_iterator it_{ form_->parts_.begin() }; + int step_ = 0; + std::uint64_t skip_ = 0; + +public: + explicit source(const multipart_form* form) noexcept + : form_{ form } + { + } + + results + on_read(buffers::mutable_buffer mb) override + { + auto rs = results{}; + + auto copy = [&](core::string_view str) + { + auto copied = buffers::buffer_copy( + mb, + buffers::sans_prefix( + buffers::const_buffer{ str.data(), str.size() }, + static_cast(skip_))); + + mb = buffers::sans_prefix(mb, copied); + rs.bytes += copied; + skip_ += copied; + + if(skip_ != str.size()) + return false; + + skip_ = 0; + return true; + }; + + auto read = [&](core::string_view path, uint64_t size) + { + http_proto::file file; + + file.open( + std::string{ path }.c_str(), + http_proto::file_mode::read, + rs.ec); + if(rs.ec) + return false; + + file.seek(skip_, rs.ec); + if(rs.ec) + return false; + + auto read = file.read( + mb.data(), + (std::min)(static_cast< + std::uint64_t>(mb.size()), size), + rs.ec); + if(rs.ec) + return false; + + mb = buffers::sans_prefix(mb, read); + rs.bytes += read; + skip_ += read; + + if(skip_ != size) + return false; + + skip_ = 0; + return true; + }; + + while(it_ != form_->parts_.end()) + { + switch(step_) + { + case 0: + // --boundary + if(!copy({ form_->storage_.begin(), + form_->storage_.size() - 2 })) return rs; + ++step_; + case 1: + if(!copy(content_disposition_)) return rs; + ++step_; + case 2: + if(!copy(it_->name)) return rs; + ++step_; + case 3: + if(!copy("\"")) return rs; + ++step_; + case 4: + if(!it_->file_size.has_value()) + goto content_type; + if(!copy(filename_)) return rs; + ++step_; + case 5: + if(!copy(filename(it_->value_or_path))) return rs; + ++step_; + case 6: + if(!copy("\"")) return rs; + ++step_; + case 7: + content_type: + if(it_->content_type.empty()) + goto end_of_header; + if(!copy(content_type_)) return rs; + ++step_; + case 8: + if(!copy(it_->content_type)) return rs; + ++step_; + case 9: + end_of_header: + if(!copy("\r\n\r\n")) return rs; + ++step_; + case 10: + if(it_->file_size) + { + if(!read( + it_->value_or_path, + it_->file_size.value())) return rs; + } + else + { + if(!copy(it_->value_or_path)) return rs; + } + ++step_; + case 11: + if(!copy("\r\n")) + return rs; + step_ = 0; + ++it_; + } + } + + // --boundary-- + if(!copy({ form_->storage_.begin(), + form_->storage_.size() })) return rs; + + rs.finished = true; + return rs; + }; +}; + +class message +{ + std::variant< + std::monostate, + urlencoded_form, + multipart_form> body_; +public: + message() = default; + + message(urlencoded_form&& form) + : body_{ std::move(form) } + { + } + + message(multipart_form&& form) + : body_{ std::move(form) } + { + } + + void + set_headers(http_proto::request& req) const + { + std::visit( + [&](auto& f) + { + if constexpr(!std::is_same_v< + decltype(f), const std::monostate&>) + { + req.set_method(http_proto::method::post); + req.set_content_length(f.content_length()); + req.set( + http_proto::field::content_type, + f.content_type()); + } + }, + body_); + } + + void + start_serializer( + http_proto::serializer& ser, + http_proto::request& req) const + { + std::visit( + [&](auto& f) + { + if constexpr(std::is_same_v< + decltype(f), const multipart_form&>) + { + ser.start< + multipart_form::source>(req, &f); + } + else if constexpr(std::is_same_v< + decltype(f), const urlencoded_form&>) + { + ser.start(req, f.body()); + } + else + { + ser.start(req); + } + }, + body_); + } +}; + + asio::awaitable connect(ssl::context& ssl_ctx, urls::url_view url) { @@ -218,42 +702,11 @@ connect(ssl::context& ssl_ctx, urls::url_view url) co_return stream; } -auto -is_redirect(http_proto::status status) noexcept -{ - struct result_t - { - bool is_redirect; - bool need_method_change; - }; - - // The specifications do not intend for 301 and 302 redirects to change the - // HTTP method, but most user agents do change the method in practice. - switch(status) - { - case http_proto::status::moved_permanently: - case http_proto::status::found: - case http_proto::status::see_other: - return result_t{ true, true }; - case http_proto::status::temporary_redirect: - case http_proto::status::permanent_redirect: - return result_t{ true, false }; - default: - return result_t{ false, false }; - } -} - -core::string_view -get_target(urls::url_view url) noexcept -{ - if(url.encoded_target().empty()) - return "/"; - - return url.encoded_target(); -} - http_proto::request -create_request(const po::variables_map& vm, urls::url_view url) +create_request( + const po::variables_map& vm, + const message& msg, + urls::url_view url) { using http_proto::field; using http_proto::method; @@ -269,10 +722,12 @@ create_request(const po::variables_map& vm, urls::url_view url) request.set_version( vm.count("http1.0") ? version::http_1_0 : version::http_1_1); - request.set_target(get_target(url)); + request.set_target(target(url)); request.set(field::accept, "*/*"); request.set(field::host, url.host()); + msg.set_headers(request); + if(vm.count("continue-at")) { auto value = "bytes=" + @@ -321,6 +776,7 @@ asio::awaitable request( const po::variables_map& vm, output_stream& output, + message& msg, ssl::context& ssl_ctx, http_proto::context& http_proto_ctx, http_proto::request request, @@ -330,7 +786,7 @@ request( auto parser = http_proto::response_parser{ http_proto_ctx }; auto serializer = http_proto::serializer{ http_proto_ctx }; - serializer.start(request); + msg.start_serializer(serializer, request); co_await http_io::async_write(stream, serializer); parser.reset(); @@ -361,16 +817,18 @@ request( if(need_method_change && !vm.count("head")) { request.set_method(http_proto::method::get); - // TODO: drop the request body + request.set_content_length(0); + request.erase(http_proto::field::content_type); + msg = {}; // drop the body } - request.set_target(get_target(redirect_url)); + request.set_target(target(redirect_url)); request.set(http_proto::field::host, redirect_url.host()); request.set(http_proto::field::referer, referer_url); referer_url = redirect_url; serializer.reset(); - serializer.start(request); + msg.start_serializer(serializer, request); co_await http_io::async_write(stream, serializer); parser.reset(); @@ -394,7 +852,8 @@ request( { for(auto cb : parser.pull_body()) { - output.write(cb); + output.write( + { static_cast(cb.data()), cb.size() }); parser.consume_body(cb.size()); } @@ -417,43 +876,51 @@ request( int main(int argc, char* argv[]) { + int co_main(int argc, char* argv[]); + //return co_main(argc, argv); try { auto odesc = po::options_description{"Options"}; odesc.add_options() - ("help,h", "produce help message") + ("compressed", "Request compressed response") + ("continue-at,C", + po::value()->value_name(""), + "Resume transfer offset") + ("data,d", + po::value>()->value_name(""), + "HTTP POST data") + ("form,F", + po::value>()->value_name(""), + "Specify multipart MIME data") ("head,I", "Show document info only") ("header,H", po::value>()->value_name("
"), "Pass custom header(s) to server") + ("help,h", "produce help message") + ("http1.0", "Use HTTP 1.0") ("location,L", "Follow redirects") - ("continue-at,C", - po::value()->value_name(""), - "Resume transfer offset") - ("range,r", - po::value()->value_name(""), - "Retrieve only the bytes within range") ("output,o", po::value()->value_name(""), "Write to file instead of stdout") + ("range,r", + po::value()->value_name(""), + "Retrieve only the bytes within range") + ("referer,e", + po::value()->value_name(""), + "Referer URL") ("request,X", po::value()->value_name(""), "Specify request method to use") ("show-headers,i", "Show response headers in the output") - ("referer,e", + ("url", po::value()->value_name(""), - "Referer URL") + "URL to work with") ("user,u", po::value()->value_name(""), "Server user and password") ("user-agent,A", po::value()->value_name(""), - "Send User-Agent to server") - ("url", - po::value()->value_name(""), - "URL to work with") - ("compressed", "Request compressed response") - ("http1.0", "Use HTTP 1.0"); + "Send User-Agent to server"); auto podesc = po::positional_options_description{}; podesc.add("url", 1); @@ -470,10 +937,11 @@ main(int argc, char* argv[]) if(vm.count("help") || !vm.count("url")) { std::cerr - << "Usage: flex_await [options...] \n" + << "Usage: flex [options...] \n" << "Example:\n" - << " flex_await https://www.example.com\n" - << " flex_await -L http://httpstat.us/301\n" + << " flex https://www.example.com\n" + << " flex -L http://httpstat.us/301\n" + << " flex https://httpbin.org/post -F name=Shadi -F img=@./avatar.jpeg\n" << odesc; return EXIT_FAILURE; } @@ -513,14 +981,78 @@ main(int argc, char* argv[]) return output_stream{}; }(); + auto msg = message{}; + + if(vm.count("form") && vm.count("data")) + throw std::runtime_error{ + "You can only select one HTTP request method"}; + + if(vm.count("form")) + { + auto form = multipart_form{}; + for(auto& data : vm.at("form").as>()) + { + if(auto pos = data.find('='); pos != std::string::npos) + { + auto name = core::string_view{ data }.substr(0, pos); + auto value = core::string_view{ data }.substr(pos + 1); + if(!value.empty() && value[0] == '@') + { + form.append_file( + name, + value.substr(1), + mime_type(value.substr(1))); + } + else + { + form.append_text(name, value, ""); + } + } + else + { + throw std::runtime_error{ + "Illegally formatted input field"}; + } + } + msg = std::move(form); + } + + if(vm.count("data")) + { + auto form = urlencoded_form{}; + for(auto& data : vm.at("data").as>()) + { + if(!data.empty() && data[0] == '@') + { + form.append_file(data.substr(1)); + } + else + { + if(auto pos = data.find('='); + pos != std::string::npos) + { + form.append_text( + data.substr(0, pos), + data.substr(pos + 1)); + } + else + { + form.append_text(data, ""); + } + } + } + msg = std::move(form); + } + asio::co_spawn( ioc, request( vm, output, + msg, ssl_ctx, http_proto_ctx, - create_request(vm, url.value()), + create_request(vm, msg, url.value()), url.value()), [](std::exception_ptr ep) {