diff --git a/src/lib/tls/asio/asio_async_ops.h b/src/lib/tls/asio/asio_async_ops.h index a6350b41959..7ab5bad9f62 100644 --- a/src/lib/tls/asio/asio_async_ops.h +++ b/src/lib/tls/asio/asio_async_ops.h @@ -19,6 +19,7 @@ // which interferes with Botan's amalgamation by defining macros like 'B0' and 'FF1'. #define BOOST_ASIO_DISABLE_SERIAL_PORT #include + #include #include namespace Botan::TLS::detail { @@ -270,27 +271,32 @@ class AsyncWriteOperation : public AsyncBase> -class AsyncHandshakeOperation : public AsyncBase { - public: - /** - * Construct and invoke an AsyncHandshakeOperation. - * - * @param handler Handler function to be called upon completion. - * @param stream The stream from which the data will be read - * @param ec Optional error code; used to report an error to the handler function. - */ - template - AsyncHandshakeOperation(HandlerT&& handler, Stream& stream, const boost::system::error_code& ec = {}) : - AsyncBase(std::forward(handler), - stream.get_executor()), - m_stream(stream) { - this->operator()(ec, std::size_t(0), false); +template +boost::asio::awaitable> async_write_some_awaitable(Stream& stream) { + size_t sent = 0; + while(stream.has_data_to_send()) { + // If we have data to be sent to the peer, we do that now. Note that + // this might either be a flight in our handshake, or a TLS alert + // record if we decided to abort due to some failure. + boost::system::error_code ec; + auto written = co_await stream.next_layer().async_write_some( + stream.send_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + stream.consume_send_buffer(written); + sent += written; + + if(ec) { + if(ec == boost::asio::error::eof && !stream.shutdown_received()) { + // transport layer was closed by peer without receiving 'close_notify' + ec = StreamError::StreamTruncated; + } + co_return std::make_pair(sent, ec); } + } - AsyncHandshakeOperation(AsyncHandshakeOperation&&) = default; + co_return std::make_pair(sent, boost::system::error_code{}); +} - /** +/** * Perform a TLS handshake with the peer. * * Depending on the situation, this handler will: @@ -302,71 +308,80 @@ class AsyncHandshakeOperation : public AsyncBase +boost::asio::awaitable async_handshake_awaitable(Stream& stream) { + boost::system::error_code ec; + while(!ec && !stream.native_handle()->is_handshake_complete()) { + if(stream.has_data_to_send()) { + // If we have data to be sent to the peer, we do that now. Note that + // this might either be a flight in our handshake, or a TLS alert + // record if we decided to abort due to some failure. + auto [written, writeEc] = co_await async_write_some_awaitable(stream); + ec = writeEc; + } else { + // If we have no more data from the peer to process and no more data + // to be sent to the peer... + + // ... we first ensure that no TLS protocol error was detected until now. + // Otherwise, the handshake is aborted with an error code. + stream.handle_tls_protocol_errors(ec); + if(ec) { + break; + } - // If we received data from the peer, we hand it to the native - // handle for processing. When enough bytes were received this will - // result in the advancement of the handshake state and produce data - // in the output buffer. - if(!ec && bytesTransferred > 0) { - boost::asio::const_buffer read_buffer{m_stream.input_buffer().data(), bytesTransferred}; - m_stream.process_encrypted_data(read_buffer); - } + size_t bytesRead = co_await stream.next_layer().async_read_some( + stream.input_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + boost::asio::const_buffer read_buffer{stream.input_buffer().data(), bytesRead}; + if(!ec && bytesRead > 0) { + stream.process_encrypted_data(read_buffer); + } + } + } - // If we have data to be sent to the peer, we do that now. Note that - // this might either be a flight in our handshake, or a TLS alert - // record if we decided to abort due to some failure. - if(!ec && m_stream.has_data_to_send()) { - // Note: we construct `AsyncWriteOperation` with 0 as its last parameter (`plainBytesTransferred`). This - // operation will eventually call `*this` as its own handler, passing the 0 back to this call operator. - // This is necessary because the check of `bytesTransferred > 0` assumes that `bytesTransferred` bytes - // were just read and are available in input_buffer for further processing. - AsyncWriteOperation::type, Stream, Allocator>, - Stream, - Allocator> - op{std::move(*this), m_stream, 0}; - return; - } + co_return ec; +} - // If we have no more data from the peer to process and no more data - // to be sent to the peer... - if(!ec && !m_stream.native_handle()->is_handshake_complete()) { - // ... we first ensure that no TLS protocol error was detected until now. - // Otherwise the handshake is aborted with an error code. - m_stream.handle_tls_protocol_errors(ec); +/** + * Perform a DTLS handshake with the peer. See async_handshake_awaitable for details. + */ +template +boost::asio::awaitable async_handshake_awaitable_dtls(Stream& stream) { + boost::system::error_code ec; + boost::asio::steady_timer timer{stream.get_executor()}; + using namespace boost::asio::experimental::awaitable_operators; + while(!ec && !stream.native_handle()->is_handshake_complete()) { + if(stream.has_data_to_send()) { + auto [written, writeEc] = co_await async_write_some_awaitable(stream); + ec = writeEc; + } else { + // ... we first ensure that no TLS protocol error was detected until now. + // Otherwise, the handshake is aborted with an error code. + stream.handle_tls_protocol_errors(ec); + if(ec) { + break; + } - if(!ec) { - // The handshake is neither finished nor aborted. Wait for - // more data from the peer. - m_stream.next_layer().async_read_some(m_stream.input_buffer(), std::move(*this)); - return; - } - } + timer.expires_from_now(std::chrono::milliseconds(1000)); - if(!isContinuation) { - // Make sure the handler is not called without an intermediate initiating function. - // "Reading" into a zero-byte buffer will complete immediately. - m_ec = ec; - yield m_stream.next_layer().async_read_some(boost::asio::mutable_buffer(), std::move(*this)); - ec = m_ec; - } + std::variant result = + co_await (stream.next_layer().async_read_some(stream.input_buffer(), boost::asio::use_awaitable) || + timer.async_wait(boost::asio::use_awaitable)); - this->complete_now(ec); + if(result.index() == 0) { + const auto& bytesRead = std::get<0>(result); + boost::asio::const_buffer read_buffer{stream.input_buffer().data(), bytesRead}; + stream.process_encrypted_data(read_buffer); } + // If we didn't receive packet, we maybe need to retransmit or + // if we received a packet, but we couldn't move on to the next state + // then the remote is probably retransmitting as it didn't receive our ACK. + // thus we need to check if we need to retransmit. + stream.native_handle()->timeout_check(); } + } - private: - Stream& m_stream; - boost::system::error_code m_ec; - boost::system::error_code m_stashed_ec; -}; + co_return ec; +} } // namespace Botan::TLS::detail diff --git a/src/lib/tls/asio/asio_stream.h b/src/lib/tls/asio/asio_stream.h index 3ec960e63c0..0285c68e982 100644 --- a/src/lib/tls/asio/asio_stream.h +++ b/src/lib/tls/asio/asio_stream.h @@ -28,8 +28,8 @@ // We need to define BOOST_ASIO_DISABLE_SERIAL_PORT before any asio imports. Otherwise asio will include , // which interferes with Botan's amalgamation by defining macros like 'B0' and 'FF1'. #define BOOST_ASIO_DISABLE_SERIAL_PORT - #include #include + #include #include #include @@ -55,19 +55,9 @@ class Stream; * future major version of Botan will therefor consume instances of this class * as a std::unique_ptr. The current usage of std::shared_ptr is erratic. */ -class StreamCallbacks : public Callbacks { +class StreamCallbacksBase : public Callbacks { public: - StreamCallbacks() {} - - void tls_emit_data(std::span data) final { - m_send_buffer.commit(boost::asio::buffer_copy(m_send_buffer.prepare(data.size()), - boost::asio::buffer(data.data(), data.size()))); - } - - void tls_record_received(uint64_t, std::span data) final { - m_receive_buffer.commit(boost::asio::buffer_copy(m_receive_buffer.prepare(data.size()), - boost::asio::const_buffer(data.data(), data.size()))); - } + StreamCallbacksBase() = default; bool tls_peer_closed_connection() final { // Instruct the TLS implementation to reply with our close_notify to @@ -115,16 +105,6 @@ class StreamCallbacks : public Callbacks { void set_context(std::weak_ptr context) { m_context = std::move(context); } - void consume_send_buffer() { m_send_buffer.consume(m_send_buffer.size()); } - - boost::beast::flat_buffer& send_buffer() { return m_send_buffer; } - - const boost::beast::flat_buffer& send_buffer() const { return m_send_buffer; } - - boost::beast::flat_buffer& receive_buffer() { return m_receive_buffer; } - - const boost::beast::flat_buffer& receive_buffer() const { return m_receive_buffer; } - bool shutdown_received() const { return m_alert_from_peer && m_alert_from_peer->type() == AlertType::CloseNotify; } @@ -133,10 +113,102 @@ class StreamCallbacks : public Callbacks { private: std::optional m_alert_from_peer; + std::weak_ptr m_context; +}; + +class StreamCallbacksTLS : public StreamCallbacksBase { + void tls_emit_data(std::span data) final { + m_send_buffer.commit(boost::asio::buffer_copy(m_send_buffer.prepare(data.size()), + boost::asio::buffer(data.data(), data.size()))); + } + + void tls_record_received(uint64_t, std::span data) final { + m_receive_buffer.commit(boost::asio::buffer_copy(m_receive_buffer.prepare(data.size()), + boost::asio::const_buffer(data.data(), data.size()))); + } + + public: + StreamCallbacksTLS(size_t bufferSize = MAX_CIPHERTEXT_SIZE) : m_input_buffer(bufferSize) {} + + size_t available() const { return m_receive_buffer.size(); } + + size_t send_count_readable_bytes() const { return m_send_buffer.size(); } + + void consume_send_buffer(size_t bytes) { m_send_buffer.consume(bytes); } + + bool has_data_to_send() const { return m_send_buffer.size(); } + + bool has_received_data() const { return available(); } + + boost::asio::const_buffer send_buffer() const { return m_send_buffer.data(); } + + boost::asio::const_buffer receive_buffer_data() const { return m_receive_buffer.data(); } + + void consume_receive_buffer(size_t bytes) { m_receive_buffer.consume(bytes); } + + boost::asio::mutable_buffer input_buffer() { return boost::asio::buffer(m_input_buffer); } + + private: + std::vector m_input_buffer; // Buffer used for receiving data (before decrypt) boost::beast::flat_buffer m_receive_buffer; boost::beast::flat_buffer m_send_buffer; +}; - std::weak_ptr m_context; +class StreamCallbacksDTLS : public StreamCallbacksBase { + void tls_emit_data(std::span data) final { + m_send_buffer.push_back(std::vector(data.begin(), data.end())); + } + + void tls_record_received(uint64_t, std::span data) final { + m_receive_buffer.push_back(std::vector(data.begin(), data.end())); + } + + public: + StreamCallbacksDTLS(size_t mtu = MAX_CIPHERTEXT_SIZE) : m_input_buffer(mtu) {} + + size_t available() const { return m_receive_buffer.empty() ? 0 : m_receive_buffer.front().size(); } + + size_t send_count_readable_bytes() const { return m_send_buffer.empty() ? 0 : m_send_buffer.front().size(); } + + void consume_send_buffer(size_t bytes) { + // pop full messages + size_t consumed = 0; + while(bytes > consumed && !m_send_buffer.empty()) { + consumed += m_send_buffer.front().size(); + m_send_buffer.pop_front(); + } + } + + bool has_data_to_send() const { return !m_send_buffer.empty(); } + + bool has_received_data() const { return available(); } + + boost::asio::const_buffer send_buffer() const { + return m_send_buffer.empty() ? boost::asio::const_buffer() + : boost::asio::const_buffer(boost::asio::buffer(m_send_buffer.front())); + } + + boost::asio::const_buffer receive_buffer_data() const { + return m_receive_buffer.empty() ? boost::asio::const_buffer() + : boost::asio::const_buffer(boost::asio::buffer(m_receive_buffer.front())); + }; + + void consume_receive_buffer(size_t bytes) { + // pop full messages + size_t consumed = 0; + while(bytes > consumed && !m_receive_buffer.empty()) { + consumed += m_receive_buffer.front().size(); + m_receive_buffer.pop_front(); + } + } + + boost::asio::mutable_buffer input_buffer() { return boost::asio::buffer(m_input_buffer); } + + private: + std::vector m_input_buffer; // Buffer used for receiving data (before decrypt) + // deque has poor performance on some compilers, so we use devector instead + boost::container::devector> m_receive_buffer; // Decrypted data + boost::container::devector> m_send_buffer; }; namespace detail { @@ -162,6 +234,9 @@ class Stream { boost::asio::default_completion_token_t>; public: + static constexpr bool m_is_dtls = !std::is_same_v; + using StreamCallbacksType = std::conditional_t; + //! \name construction //! @{ @@ -178,12 +253,10 @@ class Stream { * @param args Arguments to be forwarded to the construction of the next layer. */ template - explicit Stream(std::shared_ptr context, std::shared_ptr callbacks, Args&&... args) : - m_context(std::move(context)), - m_nextLayer(std::forward(args)...), - m_core(std::move(callbacks)), - m_input_buffer_space(MAX_CIPHERTEXT_SIZE, '\0'), - m_input_buffer(m_input_buffer_space.data(), m_input_buffer_space.size()) { + explicit Stream(std::shared_ptr context, + std::shared_ptr callbacks, + Args&&... args) : + m_context(std::move(context)), m_nextLayer(std::forward(args)...), m_core(std::move(callbacks)) { m_core->set_context(m_context); } @@ -195,7 +268,7 @@ class Stream { */ template explicit Stream(std::shared_ptr context, Args&&... args) : - Stream(std::move(context), std::make_shared(), std::forward(args)...) {} + Stream(std::move(context), std::make_shared(), std::forward(args)...) {} /** * @brief Construct a new Stream @@ -211,7 +284,7 @@ class Stream { template explicit Stream(Arg&& arg, std::shared_ptr context, - std::shared_ptr callbacks = std::make_shared()) : + std::shared_ptr callbacks = std::make_shared()) : Stream(std::move(context), std::move(callbacks), std::forward(arg)) {} virtual ~Stream() = default; @@ -249,6 +322,13 @@ class Stream { return m_native_handle.get(); } + const native_handle_type native_handle() const { + if(m_native_handle == nullptr) { + throw Invalid_State("Invalid handshake state"); + } + return m_native_handle.get(); + } + //! @} //! \name configuration and callback setters //! @{ @@ -386,14 +466,41 @@ class Stream { auto async_handshake(Botan::TLS::Connection_Side side, CompletionToken&& completion_token = default_completion_token{}) { return boost::asio::async_initiate( - [this](auto&& completion_handler, TLS::Connection_Side connection_side) { - using completion_handler_t = std::decay_t; - + [this](CallbackType&& completion_handler, TLS::Connection_Side connection_side) { boost::system::error_code ec; setup_native_handle(connection_side, ec); - - detail::AsyncHandshakeOperation op{ - std::forward(completion_handler), *this, ec}; + boost::asio::co_spawn( + get_executor(), + [this]() mutable -> boost::asio::awaitable { + if constexpr(m_is_dtls) { + boost::asio::steady_timer handshake_max_time_guard{get_executor()}; + handshake_max_time_guard.expires_after(std::chrono::seconds{6}); + using namespace boost::asio::experimental::awaitable_operators; + std::variant handshake_result = + co_await (detail::async_handshake_awaitable_dtls(*this) || + handshake_max_time_guard.async_wait(boost::asio::use_awaitable)); + if(handshake_result.index() == 0) { + co_return std::get<0>(handshake_result); + } else { + co_return boost::system::error_code{boost::asio::error::timed_out}; + } + } else { + co_return co_await detail::async_handshake_awaitable(*this); + } + }, + [completion_handler = std::forward(completion_handler)]( + std::exception_ptr eptr, const boost::system::error_code& ec) { + boost::system::error_code tmp_code = ec; + if(eptr) { + try { + } catch(boost::system::system_error& e) { + tmp_code = e.code(); + } catch(...) { + std::rethrow_exception(eptr); + } + } + completion_handler(tmp_code); + }); }, completion_token, side); @@ -444,6 +551,11 @@ class Stream { boost::asio::detail::throw_error(ec, "shutdown"); } + size_t available() const { return m_core->available(); } + + // TODO: should return error? + size_t available(boost::system::error_code& /*ec*/) const { return m_core->available(); } + private: /** * @brief Internal wrapper type to adapt the expected signature of `async_shutdown` to the completion handler @@ -518,6 +630,15 @@ class Stream { */ template std::size_t read_some(const MutableBufferSequence& buffers, boost::system::error_code& ec) { + if(has_received_data()) { + return copy_received_data(buffers); + } + size_t bytes_read = m_nextLayer.read_some(input_buffer(), ec); + boost::asio::const_buffer read_buffer(input_buffer().data(), bytes_read); + if(ec) { + return 0; + } + // We read from the socket until either some error occured or we have // decrypted at least one byte of application data. while(!ec) { @@ -622,7 +743,7 @@ class Stream { if(ec) { // we cannot be sure how many bytes were committed here so clear the send_buffer and let the // AsyncWriteOperation call the handler with the error_code set - m_core->send_buffer().consume(m_core->send_buffer().size()); + consume_send_buffer(m_core->send_count_readable_bytes()); } detail::AsyncWriteOperation op{ @@ -672,15 +793,22 @@ class Stream { friend class detail::AsyncReadOperation; template friend class detail::AsyncWriteOperation; - template - friend class detail::AsyncHandshakeOperation; - const boost::asio::mutable_buffer& input_buffer() { return m_input_buffer; } + friend boost::asio::awaitable detail::async_handshake_awaitable_dtls( + Stream& stream); + + friend boost::asio::awaitable detail::async_handshake_awaitable( + Stream& stream); + + friend boost::asio::awaitable> + detail::async_write_some_awaitable(Stream& stream); - boost::asio::const_buffer send_buffer() const { return m_core->send_buffer().data(); } + boost::asio::mutable_buffer input_buffer() { return m_core->input_buffer(); } + + boost::asio::const_buffer send_buffer() const { return m_core->send_buffer(); } //! @brief Check if decrypted data is available in the receive buffer - bool has_received_data() const { return m_core->receive_buffer().size() > 0; } + bool has_received_data() const { return m_core->has_received_data(); } //! @brief Copy decrypted data into the user-provided buffer template @@ -689,16 +817,19 @@ class Stream { // the user's desired target buffer once a read is started, and reading directly into that buffer in tls_record // received. However, we need to deal with the case that the receive buffer provided by the caller is smaller // than the decrypted record, so this optimization might not be worth the additional complexity. - const auto copiedBytes = boost::asio::buffer_copy(buffers, m_core->receive_buffer().data()); - m_core->receive_buffer().consume(copiedBytes); + const auto copiedBytes = boost::asio::buffer_copy(buffers, m_core->receive_buffer_data()); + m_core->consume_receive_buffer(copiedBytes); return copiedBytes; } //! @brief Check if encrypted data is available in the send buffer - bool has_data_to_send() const { return m_core->send_buffer().size() > 0; } + bool has_data_to_send() const { return m_core->has_data_to_send() > 0; } //! @brief Mark bytes in the send buffer as consumed, removing them from the buffer - void consume_send_buffer(std::size_t bytesConsumed) { m_core->send_buffer().consume(bytesConsumed); } + void consume_send_buffer(std::size_t bytesConsumed) { m_core->consume_send_buffer(bytesConsumed); } + + //! @brief Mark bytes in the receive buffer as consumed, removing them from the buffer + void consume_receive_buffer(std::size_t bytesConsumed) { m_core->consume_receive_buffer(bytesConsumed); } /** * @brief Create the native handle. @@ -718,21 +849,21 @@ class Stream { try_with_error_code( [&] { if(side == Connection_Side::Client) { - m_native_handle = std::unique_ptr( - new Client(m_core, - m_context->m_session_manager, - m_context->m_credentials_manager, - m_context->m_policy, - m_context->m_rng, - m_context->m_server_info, - m_context->m_policy->latest_supported_version(false /* no DTLS */))); + m_native_handle = + std::unique_ptr(new Client(m_core, + m_context->m_session_manager, + m_context->m_credentials_manager, + m_context->m_policy, + m_context->m_rng, + m_context->m_server_info, + m_context->m_policy->latest_supported_version(m_is_dtls))); } else { m_native_handle = std::unique_ptr(new Server(m_core, m_context->m_session_manager, m_context->m_credentials_manager, m_context->m_policy, m_context->m_rng, - false /* no DTLS */)); + m_is_dtls)); } }, ec); @@ -923,7 +1054,7 @@ class Stream { std::shared_ptr m_context; StreamLayer m_nextLayer; - std::shared_ptr m_core; + std::shared_ptr m_core; std::unique_ptr m_native_handle; boost::system::error_code m_ec_from_last_read; diff --git a/src/lib/tls/tls12/tls_channel_impl_12.cpp b/src/lib/tls/tls12/tls_channel_impl_12.cpp index 5adcc7de7dc..2a723c3e7c5 100644 --- a/src/lib/tls/tls12/tls_channel_impl_12.cpp +++ b/src/lib/tls/tls12/tls_channel_impl_12.cpp @@ -343,10 +343,12 @@ size_t Channel_Impl_12::from_peer(std::span data) { if(m_has_been_closed) { throw TLS_Exception(Alert::UnexpectedMessage, "Received application data after connection closure"); } - if(pending_state() != nullptr) { + // If we're in a handshake we can possibly ignore the data for DTLS. It's an error for TLS. + if(pending_state() == nullptr) { + process_application_data(record.sequence(), m_record_buf); + } else if(!m_is_datagram) { throw TLS_Exception(Alert::UnexpectedMessage, "Can't interleave application and handshake data"); } - process_application_data(record.sequence(), m_record_buf); } else if(record.type() == Record_Type::Alert) { process_alert(m_record_buf); } else if(record.type() != Record_Type::Invalid) { diff --git a/src/lib/tls/tls12/tls_handshake_io.cpp b/src/lib/tls/tls12/tls_handshake_io.cpp index c6be0a0a205..1273a0a7f29 100644 --- a/src/lib/tls/tls12/tls_handshake_io.cpp +++ b/src/lib/tls/tls12/tls_handshake_io.cpp @@ -367,7 +367,6 @@ std::vector Datagram_Handshake_IO::send_under_epoch(const Handshake_Mes m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits); m_out_message_seq += 1; - m_last_write = steady_clock_ms(); m_next_timeout = m_initial_timeout; return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); @@ -378,7 +377,7 @@ std::vector Datagram_Handshake_IO::send_message(uint16_t msg_seq, Handshake_Type msg_type, const std::vector& msg_bits) { const size_t DTLS_HANDSHAKE_HEADER_LEN = 12; - + m_last_write = steady_clock_ms(); auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) {