diff --git a/include/condy.hpp b/include/condy.hpp index e32f1eb..c1bb402 100644 --- a/include/condy.hpp +++ b/include/condy.hpp @@ -21,6 +21,7 @@ #include "condy/sync_wait.hpp" // IWYU pragma: export #include "condy/task.hpp" // IWYU pragma: export #include "condy/version.hpp" // IWYU pragma: export +#include "condy/zcrx.hpp" // IWYU pragma: export /** * @brief The main namespace for the Condy library. diff --git a/include/condy/async_operations.hpp b/include/condy/async_operations.hpp index 4d8b83b..b801d59 100644 --- a/include/condy/async_operations.hpp +++ b/include/condy/async_operations.hpp @@ -12,6 +12,7 @@ #include "condy/concepts.hpp" #include "condy/condy_uring.hpp" #include "condy/helpers.hpp" +#include "condy/zcrx.hpp" namespace condy { @@ -701,6 +702,38 @@ inline auto async_recv_multishot(Fd sockfd, Buffer &buf, int flags, } #endif +#if !IO_URING_CHECK_VERSION(2, 15) // >= 2.15 + +namespace detail { + +inline void prep_recv_zc_multishot(io_uring_sqe *sqe, int fd, + uint32_t zcrx_id) { + io_uring_prep_rw(IORING_OP_RECV_ZC, sqe, fd, nullptr, 0, 0); + sqe->ioprio |= IORING_RECV_MULTISHOT; + sqe->zcrx_ifq_idx = zcrx_id; +} + +} // namespace detail + +// TODO: Consider the function signature later... +template +inline auto async_recv_multishot(Fd fd, ZeroCopyRxBufferPool &pool, + [[maybe_unused]] int flags, + MultiShotFunc &&func) { + auto zcrx_id = pool.zcrx_id(); + auto prep_func = [=](Ring *ring) { + auto *sqe = ring->get_sqe(); + detail::prep_recv_zc_multishot(sqe, fd, zcrx_id); + return sqe; + }; + auto op = build_multishot_op_awaiter< + SelectBufferCQEHandler>( + std::move(prep_func), std::forward(func), &pool); + return detail::maybe_flag_fixed_fd(std::move(op), fd); +} + +#endif + /** * @brief See io_uring_prep_openat2 */ diff --git a/include/condy/buffers.hpp b/include/condy/buffers.hpp index 155a32e..771ce8e 100644 --- a/include/condy/buffers.hpp +++ b/include/condy/buffers.hpp @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace condy { @@ -190,4 +191,66 @@ inline MutableBuffer buffer(std::span sp) noexcept { sp.size() * sizeof(PodType)); } +namespace detail { + +template struct ManagedBuffer : public BufferBase { +public: + ManagedBuffer() = default; + ManagedBuffer(void *data, size_t size, BufferPool *pool) + : data_(data), size_(size), pool_(pool) {} + ManagedBuffer(ManagedBuffer &&other) noexcept + : data_(std::exchange(other.data_, nullptr)), + size_(std::exchange(other.size_, 0)), + pool_(std::exchange(other.pool_, nullptr)) {} + ManagedBuffer &operator=(ManagedBuffer &&other) noexcept { + if (this != &other) { + reset(); + data_ = std::exchange(other.data_, nullptr); + size_ = std::exchange(other.size_, 0); + pool_ = std::exchange(other.pool_, nullptr); + } + return *this; + } + + ~ManagedBuffer() { reset(); } + + ManagedBuffer(const ManagedBuffer &) = delete; + ManagedBuffer &operator=(const ManagedBuffer &) = delete; + +public: + /** + * @brief Get the data pointer of the buffer + */ + void *data() const noexcept { return data_; } + + /** * + * @brief Get the size of the buffer + */ + size_t size() const noexcept { return size_; } + + /** + * @brief Reset the buffer, returning it to the pool if owned + */ + void reset() noexcept { + if (pool_ != nullptr) { + pool_->add_buffer_back(data_, size_); + } + data_ = nullptr; + size_ = 0; + pool_ = nullptr; + } + + /** + * @brief Check if the buffer owns a buffer from a pool. + */ + bool owns_buffer() const noexcept { return pool_ != nullptr; } + +private: + void *data_ = nullptr; + size_t size_ = 0; + BufferPool *pool_ = nullptr; +}; + +} // namespace detail + } // namespace condy \ No newline at end of file diff --git a/include/condy/cqe_handler.hpp b/include/condy/cqe_handler.hpp index 188feb0..803e194 100644 --- a/include/condy/cqe_handler.hpp +++ b/include/condy/cqe_handler.hpp @@ -55,7 +55,9 @@ struct SimpleCQEHandler { * result of the operation (the value of `cqe->res`) and the selected buffer, * whose type is determined by the buffer ring. */ -template class SelectBufferCQEHandler { +template + requires(requires(Br *br, io_uring_cqe *cqe) { br->handle_finish(cqe); }) +class SelectBufferCQEHandler { public: SelectBufferCQEHandler(Br *buffers) : buffers_(buffers) {} diff --git a/include/condy/provided_buffers.hpp b/include/condy/provided_buffers.hpp index fc047e7..e34ac15 100644 --- a/include/condy/provided_buffers.hpp +++ b/include/condy/provided_buffers.hpp @@ -227,57 +227,7 @@ class BundledProvidedBufferPool; * @note The lifetime of the provided buffer must not exceed the lifetime of the * provided buffer pool it is associated with. */ -struct ProvidedBuffer : public BufferBase { -public: - ProvidedBuffer() = default; - ProvidedBuffer(void *data, size_t size, - detail::BundledProvidedBufferPool *pool) - : data_(data), size_(size), pool_(pool) {} - ProvidedBuffer(ProvidedBuffer &&other) noexcept - : data_(std::exchange(other.data_, nullptr)), - size_(std::exchange(other.size_, 0)), - pool_(std::exchange(other.pool_, nullptr)) {} - ProvidedBuffer &operator=(ProvidedBuffer &&other) noexcept { - if (this != &other) { - reset(); - data_ = std::exchange(other.data_, nullptr); - size_ = std::exchange(other.size_, 0); - pool_ = std::exchange(other.pool_, nullptr); - } - return *this; - } - - ~ProvidedBuffer() { reset(); } - - ProvidedBuffer(const ProvidedBuffer &) = delete; - ProvidedBuffer &operator=(const ProvidedBuffer &) = delete; - -public: - /** - * @brief Get the data pointer of the provided buffer - */ - void *data() const noexcept { return data_; } - - /** * - * @brief Get the size of the provided buffer - */ - size_t size() const noexcept { return size_; } - - /** - * @brief Reset the provided buffer, returning it to the pool if owned - */ - void reset() noexcept; - - /** - * @brief Check if the provided buffer owns a buffer from a pool. - */ - bool owns_buffer() const noexcept { return pool_ != nullptr; } - -private: - void *data_ = nullptr; - size_t size_ = 0; - detail::BundledProvidedBufferPool *pool_ = nullptr; -}; +using ProvidedBuffer = detail::ManagedBuffer; namespace detail { @@ -397,7 +347,8 @@ class BundledProvidedBufferPool { return buffers; } - void add_buffer_back(void *ptr) noexcept { + void add_buffer_back(void *ptr, [[maybe_unused]] size_t size) noexcept { + assert(size <= buffer_size_); char *base = get_buffers_base_(); assert(ptr >= base); size_t offset = static_cast(ptr) - base; @@ -437,15 +388,6 @@ class BundledProvidedBufferPool { } // namespace detail -inline void ProvidedBuffer::reset() noexcept { - if (pool_ != nullptr) { - pool_->add_buffer_back(data_); - } - data_ = nullptr; - size_ = 0; - pool_ = nullptr; -} - /** * @brief Provided buffer pool. * @details A provided buffer pool manages a pool of buffers that can be used in diff --git a/include/condy/utils.hpp b/include/condy/utils.hpp index c6c4b90..539e08a 100644 --- a/include/condy/utils.hpp +++ b/include/condy/utils.hpp @@ -227,4 +227,10 @@ std::variant tuple_at(std::tuple &results, size_t idx) { } } +template inline T align_up(T value, size_t alignment) noexcept { + // alignment must be a power of two + assert(alignment > 0 && (alignment & (alignment - 1)) == 0); + return (value + alignment - 1) & ~(alignment - 1); +} + } // namespace condy diff --git a/include/condy/zcrx.hpp b/include/condy/zcrx.hpp new file mode 100644 index 0000000..bfd909c --- /dev/null +++ b/include/condy/zcrx.hpp @@ -0,0 +1,280 @@ +#pragma once + +#include "condy/buffers.hpp" +#include "condy/condy_uring.hpp" +#include "condy/context.hpp" +#include "condy/ring.hpp" +#include "condy/utils.hpp" + +namespace condy { + +#if !IO_URING_CHECK_VERSION(2, 15) // >= 2.15 + +class ZeroCopyRxBufferPool; + +/** + * @brief Buffer from a ZeroCopyRxBufferPool. + * @details This buffer type is used for buffers obtained from a + * ZeroCopyRxBufferPool. It automatically returns the buffer to the pool when it + * is out of scope. + * @note The lifetime of the buffer must not exceed the lifetime of the + * ZeroCopyRxBufferPool it is associated with. + */ +using ZeroCopyRxBuffer = detail::ManagedBuffer; + +/** + * @brief Area for zero-copy receive buffers. + */ +struct ZeroCopyRxArea { + void *addr = nullptr; + size_t size; +}; + +/** + * @brief Area for zero-copy receive buffers using DMA-BUF. + */ +struct ZeroCopyRxDMABufArea { + int dmabuf_fd; + size_t offset; + size_t size; +}; + +/** + * @brief Buffer pool for zero-copy receive buffers. + * @details This buffer pool utilizes the io_uring zcrx feature to provide + * zero-copy receive buffers. It can be used to receive data directly into + * user-space buffers without copying, which can improve performance for + * high-throughput network applications. + * @returns std::pair When passed to async + * operations, the return type will be a pair of the operation result and the + * @ref ZeroCopyRxBuffer. + * @note The lifetime of this pool must not exceed the running period of the + * associated Runtime, and the lifetime of any ZeroCopyRxBuffer obtained from + * this pool must not exceed the lifetime of this pool. + */ +class ZeroCopyRxBufferPool { +public: + /** + * @brief Construct a new Zero Copy Rx Buffer Pool object + * @param if_idx Network interface index to register the buffer pool with. + * @param if_rxq Receive queue index to register the buffer pool with. + * @param rq_entries Number of receive queue entries. + * @param area Area for zero-copy receive buffers. + */ + ZeroCopyRxBufferPool(uint32_t if_idx, uint32_t if_rxq, uint32_t rq_entries, + const ZeroCopyRxArea &area) + : ZeroCopyRxBufferPool(if_idx, if_rxq, rq_entries, area, 0) {} + + // Device-less constructor, DO NOT use this in production code if you don't + // know what you are doing. + ZeroCopyRxBufferPool(uint32_t rq_entries, const ZeroCopyRxArea &area) + : ZeroCopyRxBufferPool(0, 0, rq_entries, area, ZCRX_REG_NODEV) { + device_less_ = true; + } + + /** + * @brief Construct a new Zero Copy Rx Buffer Pool object + * @param if_idx Network interface index to register the buffer pool with. + * @param if_rxq Receive queue index to register the buffer pool with. + * @param rq_entries Number of receive queue entries. + * @param area Area for zero-copy receive buffers using DMA-BUF. + */ + ZeroCopyRxBufferPool(uint32_t if_idx, uint32_t if_rxq, uint32_t rq_entries, + const ZeroCopyRxDMABufArea &area) { + area_size_ = 0; + area_ptr_ = nullptr; + + io_uring_zcrx_area_reg area_reg = {}; + area_reg.addr = area.offset; + area_reg.len = area.size; + area_reg.flags = IORING_ZCRX_AREA_DMABUF; + + register_ifq_(if_idx, if_rxq, rq_entries, area_reg, + sysconf(_SC_PAGESIZE), 0); + } + + ~ZeroCopyRxBufferPool() { + [[maybe_unused]] int r; + if (area_size_ > 0) { + assert(area_ptr_ != nullptr); + r = munmap(area_ptr_, area_size_); + assert(r == 0); + } + assert(rq_ring_.ring_ptr != nullptr); + r = munmap(rq_ring_.ring_ptr, ring_size_); + assert(r == 0); + // TODO: Unregister ifq + } + + ZeroCopyRxBufferPool(const ZeroCopyRxBufferPool &) = delete; + ZeroCopyRxBufferPool &operator=(const ZeroCopyRxBufferPool &) = delete; + ZeroCopyRxBufferPool(ZeroCopyRxBufferPool &&) = delete; + ZeroCopyRxBufferPool &operator=(ZeroCopyRxBufferPool &&) = delete; + +private: + ZeroCopyRxBufferPool(uint32_t if_idx, uint32_t if_rxq, uint32_t rq_entries, + const ZeroCopyRxArea &area, uint32_t flags) { + const size_t page_size = sysconf(_SC_PAGESIZE); + + if (area.addr == nullptr) { + area_size_ = align_up(area.size, page_size); + area_ptr_ = mmap(nullptr, area_size_, PROT_READ | PROT_WRITE, + MAP_ANONYMOUS | MAP_PRIVATE, 0, 0); + if (area_ptr_ == MAP_FAILED) { + throw make_system_error("mmap"); + } + auto d = defer([&]() { munmap(area_ptr_, area_size_); }); + + io_uring_zcrx_area_reg area_reg = {}; + area_reg.addr = reinterpret_cast(area_ptr_); + area_reg.len = area_size_; + area_reg.flags = 0; + + register_ifq_(if_idx, if_rxq, rq_entries, area_reg, page_size, + flags); + + d.dismiss(); + } else { + // Not owned, so we don't track the size for unmapping + area_size_ = 0; + area_ptr_ = area.addr; + + io_uring_zcrx_area_reg area_reg = {}; + area_reg.addr = reinterpret_cast(area_ptr_); + area_reg.len = area.size; + area_reg.flags = 0; + + register_ifq_(if_idx, if_rxq, rq_entries, area_reg, page_size, + flags); + } + } + +public: + uint32_t zcrx_id() const noexcept { return zcrx_id_; } + + ZeroCopyRxBuffer handle_finish(io_uring_cqe *cqe) noexcept { + if (cqe->res < 0) { + return ZeroCopyRxBuffer(); + } + io_uring_zcrx_cqe *rcqe = + reinterpret_cast(cqe + 1); + void *data = static_cast(area_ptr_) + + (rcqe->off & ~IORING_ZCRX_AREA_MASK); + size_t size = static_cast(cqe->res); + return ZeroCopyRxBuffer(data, size, this); + } + + void add_buffer_back(void *ptr, size_t size) noexcept { + rq_enqueue_(ptr, size); + maybe_flush_rq_(); + } + +private: + void register_ifq_(uint32_t if_idx, uint32_t if_rxq, uint32_t rq_entries, + io_uring_zcrx_area_reg &area_reg, size_t page_size, + uint32_t flags) { + rq_entries = std::bit_ceil(rq_entries); + io_uring_region_desc region_reg = {}; + ring_size_ = get_refill_ring_size_(rq_entries, page_size); + region_reg.user_addr = 0; + region_reg.size = ring_size_; + region_reg.flags = 0; + + io_uring_zcrx_ifq_reg reg = {}; + reg.if_idx = if_idx; + reg.if_rxq = if_rxq; + reg.rq_entries = rq_entries; + reg.area_ptr = reinterpret_cast(&area_reg); + reg.region_ptr = reinterpret_cast(®ion_reg); + reg.flags = flags; + + auto *ring = detail::Context::current().ring(); + int r = io_uring_register_ifq(ring->ring(), ®); + if (r != 0) { + throw make_system_error("io_uring_register_ifq", -r); + } + // TODO: unregister ifq if any exception + + void *ring_ptr = mmap(nullptr, ring_size_, PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_POPULATE, ring->ring()->ring_fd, + static_cast(region_reg.mmap_offset)); + if (ring_ptr == MAP_FAILED) { + throw make_system_error("mmap"); + } + rq_ring_.khead = (unsigned int *)((char *)ring_ptr + reg.offsets.head); + rq_ring_.ktail = (unsigned int *)((char *)ring_ptr + reg.offsets.tail); + rq_ring_.rqes = + (struct io_uring_zcrx_rqe *)((char *)ring_ptr + reg.offsets.rqes); + rq_ring_.rq_tail = 0; + rq_ring_.ring_entries = reg.rq_entries; + rq_ring_.ring_ptr = ring_ptr; + + zcrx_id_ = reg.zcrx_id; + area_token_ = area_reg.rq_area_token; + } + + static size_t get_refill_ring_size_(uint32_t rq_entries, + size_t page_size) noexcept { + size_t ring_size = rq_entries * sizeof(io_uring_zcrx_rqe); + ring_size += page_size; + ring_size = align_up(ring_size, page_size); + return ring_size; + } + + size_t rq_nr_queued_() const noexcept { + return rq_ring_.rq_tail - io_uring_smp_load_acquire(rq_ring_.khead); + } + + static int io_uring_register_zcrx_ctrl_(struct io_uring *ring, + struct zcrx_ctrl *ctrl) noexcept { + unsigned int opcode = IORING_REGISTER_ZCRX_CTRL; + int fd; + if (ring->int_flags & 1) { + opcode |= IORING_REGISTER_USE_REGISTERED_RING; + fd = ring->enter_ring_fd; + } else { + fd = ring->ring_fd; + } + return io_uring_register(fd, opcode, ctrl, 0); + } + + void rq_enqueue_(void *ptr, size_t size) noexcept { + assert(rq_nr_queued_() < rq_ring_.ring_entries); + io_uring_zcrx_rqe *rqe; + unsigned rq_mask = rq_ring_.ring_entries - 1; + rqe = &rq_ring_.rqes[rq_ring_.rq_tail & rq_mask]; + rqe->off = (static_cast(ptr) - static_cast(area_ptr_)) | + area_token_; + rqe->len = static_cast(size); + io_uring_smp_store_release(rq_ring_.ktail, ++rq_ring_.rq_tail); + } + + void flush_rq_() noexcept { + auto *ring = detail::Context::current().ring(); + zcrx_ctrl ctrl = {}; + ctrl.zcrx_id = zcrx_id_; + ctrl.op = ZCRX_CTRL_FLUSH_RQ; + [[maybe_unused]] int r = + io_uring_register_zcrx_ctrl_(ring->ring(), &ctrl); + assert(r == 0); + } + + void maybe_flush_rq_() noexcept { + if (rq_nr_queued_() >= rq_ring_.ring_entries || device_less_) { + flush_rq_(); + } + } + +private: + void *area_ptr_; + size_t area_size_; + size_t ring_size_; + io_uring_zcrx_rq rq_ring_; + uint32_t zcrx_id_; + uint64_t area_token_; + bool device_less_ = false; +}; + +#endif + +} // namespace condy \ No newline at end of file diff --git a/tests/test_async_operations.4.cpp b/tests/test_async_operations.4.cpp index 80e9091..236584e 100644 --- a/tests/test_async_operations.4.cpp +++ b/tests/test_async_operations.4.cpp @@ -4,6 +4,7 @@ #include "condy/runtime.hpp" #include "condy/runtime_options.hpp" #include "condy/sync_wait.hpp" +#include "condy/zcrx.hpp" #include "helpers.hpp" #include #include @@ -1079,4 +1080,58 @@ TEST_CASE("test async_operations - test pipe - direct") { }; condy::sync_wait(func()); } +#endif + +#if !IO_URING_CHECK_VERSION(2, 15) // >= 2.15 +TEST_CASE("test async_operations - test recv - zc multishot") { + int sv[2]; + create_tcp_socketpair(sv); + + condy::Runtime runtime( + condy::RuntimeOptions().enable_cqe32().enable_defer_taskrun()); + + auto msg = generate_data(9ul * 4096); + ssize_t r = send(sv[1], msg.data(), msg.size(), 0); + REQUIRE(r == msg.size()); + close(sv[1]); + + auto func = [&]() -> condy::Coro { + size_t count = 0; + std::string actual; + + condy::ZeroCopyRxBufferPool pool( + 256, condy::ZeroCopyRxArea{.size = 8ul * 4096}); + + condy::Channel channel(16); + + auto [n, buf] = + co_await condy::async_recv_multishot(sv[0], pool, 0, [&](auto res) { + auto &[n, buf] = res; + REQUIRE(n == 4096); + actual.append(static_cast(buf.data()), n); + count++; + REQUIRE(channel.try_push(std::move(buf)) == 0); + }); + REQUIRE(n == -ENOMEM); + REQUIRE(count == 8); + + auto [r, tmp] = co_await channel.pop(); + tmp.reset(); // Release the buffer back to the pool + + auto [n2, buf2] = + co_await condy::async_recv_multishot(sv[0], pool, 0, [&](auto res) { + auto &[n, buf] = res; + REQUIRE(n == 4096); + actual.append(static_cast(buf.data()), n); + count++; + }); + REQUIRE(n2 == 0); + REQUIRE(count == 9); + + REQUIRE(actual == msg); + }; + condy::sync_wait(runtime, func()); + + close(sv[0]); +} #endif \ No newline at end of file