diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 3d95be000..5db860ea7 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -80,23 +80,16 @@ class MooncakeStorePyWrapper { try { // Section with GIL released py::gil_scoped_release release_gil; - auto buffer_handle = store_.get_buffer(key); - if (!buffer_handle) { - py::gil_scoped_acquire acquire_gil; - return pybind11::none(); - } - // Create contiguous buffer and copy data - auto total_length = buffer_handle->size(); - char *exported_data = new char[total_length]; - if (!exported_data) { + uint64_t total_length = 0; + auto get_result = store_.get_allocated_internal(key, total_length); + if (!get_result) { py::gil_scoped_acquire acquire_gil; - LOG(ERROR) << "Invalid data format: insufficient data for " - "metadata"; return pybind11::none(); } + auto exported_data = *get_result; + + // Copy metadata from buffer TensorMetadata metadata; - // Copy data from buffer to contiguous memory - memcpy(exported_data, buffer_handle->ptr(), total_length); memcpy(&metadata, exported_data, sizeof(TensorMetadata)); if (metadata.ndim < 0 || metadata.ndim > 4) { diff --git a/mooncake-store/include/client_buffer.hpp b/mooncake-store/include/client_buffer.hpp index 8d1956445..8d732013c 100644 --- a/mooncake-store/include/client_buffer.hpp +++ b/mooncake-store/include/client_buffer.hpp @@ -97,6 +97,17 @@ std::vector split_into_slices(BufferHandle& handle); */ uint64_t calculate_total_size(const Replica::Descriptor& replica); +/** + * @brief Allocate slices from a buffer pointer based on replica descriptor + * @param slices Output vector to store the allocated slices + * @param replica The replica descriptor defining the slice structure + * @param buffer The buffer pointer to allocate slices from + * @return 0 on success, non-zero on error + */ +int allocateSlices(std::vector& slices, + const Replica::Descriptor& replica, + char* buffer); + /** * @brief Allocate slices from a buffer handle based on replica descriptor * @param slices Output vector to store the allocated slices diff --git a/mooncake-store/include/pybind_client.h b/mooncake-store/include/pybind_client.h index 207c9449d..5d9adc1e4 100644 --- a/mooncake-store/include/pybind_client.h +++ b/mooncake-store/include/pybind_client.h @@ -273,6 +273,9 @@ class PyClient { const std::vector> &values, const ReplicateConfig &config = ReplicateConfig{}); + tl::expected get_allocated_internal( + const std::string &key, uint64_t &data_length); + tl::expected remove_internal(const std::string &key); tl::expected removeByRegex_internal( diff --git a/mooncake-store/src/client_buffer.cpp b/mooncake-store/src/client_buffer.cpp index f715e7d2f..b820a7361 100644 --- a/mooncake-store/src/client_buffer.cpp +++ b/mooncake-store/src/client_buffer.cpp @@ -78,15 +78,14 @@ uint64_t calculate_total_size(const Replica::Descriptor& replica) { } int allocateSlices(std::vector& slices, - const Replica::Descriptor& replica, - BufferHandle& buffer_handle) { + const Replica::Descriptor& replica, char* buffer) { uint64_t offset = 0; if (replica.is_memory_replica() == false) { // For disk-based replica, split into slices based on file size uint64_t total_length = replica.get_disk_descriptor().object_size; while (offset < total_length) { auto chunk_size = std::min(total_length - offset, kMaxSliceSize); - void* chunk_ptr = static_cast(buffer_handle.ptr()) + offset; + void* chunk_ptr = buffer + offset; slices.emplace_back(Slice{chunk_ptr, chunk_size}); offset += chunk_size; } @@ -95,7 +94,7 @@ int allocateSlices(std::vector& slices, // descriptors for (auto& handle : replica.get_memory_descriptor().buffer_descriptors) { - void* chunk_ptr = static_cast(buffer_handle.ptr()) + offset; + void* chunk_ptr = buffer + offset; slices.emplace_back(Slice{chunk_ptr, handle.size_}); offset += handle.size_; } @@ -103,4 +102,11 @@ int allocateSlices(std::vector& slices, return 0; } +int allocateSlices(std::vector& slices, + const Replica::Descriptor& replica, + BufferHandle& buffer_handle) { + return allocateSlices(slices, replica, + static_cast(buffer_handle.ptr())); +} + } // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/pybind_client.cpp b/mooncake-store/src/pybind_client.cpp index c0d9e06a9..b484f52e0 100644 --- a/mooncake-store/src/pybind_client.cpp +++ b/mooncake-store/src/pybind_client.cpp @@ -1091,4 +1091,67 @@ int PyClient::put_from_with_metadata(const std::string &key, void *buffer, return 0; } +tl::expected PyClient::get_allocated_internal( + const std::string &key, uint64_t &data_length) { + // Query object info first + auto query_result = client_->Query(key); + if (!query_result) { + LOG(ERROR) << "Query failed: " << query_result.error(); + return tl::unexpected(query_result.error()); + } + + auto replica_list = query_result.value(); + if (replica_list.empty()) { + LOG(INFO) << "No replicas found for key: " << key; + return tl::unexpected(ErrorCode::INVALID_KEY); + } + + const auto &replica = replica_list[0]; + uint64_t total_length = calculate_total_size(replica); + if (total_length == 0) { + LOG(ERROR) << "Zero length value for key: " << key; + return tl::unexpected(ErrorCode::INVALID_KEY); + } + + // Create contiguous buffer to read data + char *data_ptr = new char[total_length]; + if (!data_ptr) { + LOG(ERROR) << "Failed to allocate memory for length: " << total_length; + return tl::unexpected(ErrorCode::INTERNAL_ERROR); + } + + // register the buffer + auto register_result = register_buffer_internal( + reinterpret_cast(data_ptr), total_length); + if (!register_result) { + LOG(ERROR) << "Failed to register buffer"; + delete[] data_ptr; + return tl::unexpected(register_result.error()); + } + + // Create slices for the allocated buffer + std::vector slices; + allocateSlices(slices, replica, data_ptr); + + // Get the object data + auto get_result = client_->Get(key, replica_list, slices); + + // unregister the buffer for whatever cases + auto unregister_result = + unregister_buffer_internal(reinterpret_cast(data_ptr)); + if (!unregister_result) { + LOG(WARNING) << "Failed to unregister buffer"; + } + + if (!get_result) { + delete[] data_ptr; + LOG(ERROR) << "Get failed for key: " << key; + return tl::unexpected(get_result.error()); + } + + // return the data ptr transferring the ownership to the caller + data_length = total_length; + return data_ptr; +} + } // namespace mooncake