From 79646a4ef699d54534bf392826febb9794f749a8 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 12 Dec 2025 09:48:56 +0800 Subject: [PATCH 1/2] issue/125 - cache interface --- csrc/cache/cache.hpp | 3 +- csrc/cache/cache_factory.cpp | 24 ++ csrc/cache/cache_interface.hpp | 78 +++++ csrc/cache/dynamic_cache/dynamic_cache.cpp | 283 ++++++++++++++++ csrc/cache/dynamic_cache/dynamic_cache.hpp | 166 +++++++++ csrc/cache/kv_cache.hpp | 369 --------------------- csrc/engine/rank_worker.cpp | 13 +- csrc/engine/rank_worker.hpp | 2 +- csrc/models/llama/llama_attention.cpp | 2 +- csrc/models/llama/llama_attention.hpp | 2 +- csrc/models/llama/llama_model.hpp | 6 +- csrc/models/model_factory.cpp | 2 +- csrc/models/model_factory.hpp | 2 +- csrc/pybind11/models/llama.hpp | 20 +- 14 files changed, 579 insertions(+), 393 deletions(-) create mode 100644 csrc/cache/cache_factory.cpp create mode 100644 csrc/cache/cache_interface.hpp create mode 100644 csrc/cache/dynamic_cache/dynamic_cache.cpp create mode 100644 csrc/cache/dynamic_cache/dynamic_cache.hpp delete mode 100644 csrc/cache/kv_cache.hpp diff --git a/csrc/cache/cache.hpp b/csrc/cache/cache.hpp index e8722b39..da442417 100644 --- a/csrc/cache/cache.hpp +++ b/csrc/cache/cache.hpp @@ -1,4 +1,5 @@ #pragma once #include "cache_config.hpp" -#include "kv_cache.hpp" +#include "cache_interface.hpp" +#include "dynamic_cache/dynamic_cache.hpp" diff --git a/csrc/cache/cache_factory.cpp b/csrc/cache/cache_factory.cpp new file mode 100644 index 00000000..c4d6fd91 --- /dev/null +++ b/csrc/cache/cache_factory.cpp @@ -0,0 +1,24 @@ +#include "cache_interface.hpp" +#include "dynamic_cache/dynamic_cache.hpp" +#include + +namespace infinilm::cache { + +std::shared_ptr CacheInterface::create(const CacheConfig &config) { + switch (config.type) { + case CacheType::DYNAMIC: + return std::make_shared(config); + + case CacheType::PAGED: + // Return PagedCache when implemented + // return std::make_shared(config); + spdlog::warn("PagedCache not yet implemented, falling back to DynamicCache"); + return std::make_shared(config); + + default: + spdlog::error("Unknown cache type: {}", static_cast(config.type)); + throw std::runtime_error("Unknown cache type"); + } +} + +} // namespace infinilm::cache diff --git a/csrc/cache/cache_interface.hpp b/csrc/cache/cache_interface.hpp new file mode 100644 index 00000000..fa21c4a7 --- /dev/null +++ b/csrc/cache/cache_interface.hpp @@ -0,0 +1,78 @@ +#pragma once + +#include "cache_config.hpp" +#include "infinicore/tensor.hpp" + +#include + +namespace infinilm::cache { + +/** + * @brief Abstract interface for KV cache implementations + * This allows different cache types (Dynamic, Paged, etc.) to be used interchangeably + */ +class CacheInterface { +public: + virtual ~CacheInterface() = default; + + /** + * @brief Update cache with new key and value states + * @param layer_idx Layer index for multi-layer models + * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] + * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] + * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] + */ + virtual std::pair update( + size_t layer_idx, + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) + = 0; + + /** + * @brief Update cache (convenience method for single-layer or default layer) + */ + virtual std::pair update( + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) { + return update(0, k_new, v_new); + } + + /** + * @brief Reset cache for all layers to a specific position + * @param pos Position to reset to (defaults to 0) + */ + virtual void reset(size_t pos = 0) = 0; + + /** + * @brief Update cache configuration + * @param new_config New cache configuration + */ + virtual void update_config(const CacheConfig &new_config) = 0; + + /** + * @brief Get current cache configuration + */ + virtual const CacheConfig &get_config() const = 0; + + /** + * @brief Get the number of layers in this cache + */ + virtual size_t num_layers() const = 0; + + /** + * @brief Get cache position for a specific layer + */ + virtual size_t cache_position(size_t layer_idx) const = 0; + + /** + * @brief Check if cache is initialized + */ + virtual bool is_initialized() const = 0; + + /** + * @brief Factory method to create cache based on configuration + */ + static std::shared_ptr create(const CacheConfig &config); +}; + +} // namespace infinilm::cache diff --git a/csrc/cache/dynamic_cache/dynamic_cache.cpp b/csrc/cache/dynamic_cache/dynamic_cache.cpp new file mode 100644 index 00000000..30e519c5 --- /dev/null +++ b/csrc/cache/dynamic_cache/dynamic_cache.cpp @@ -0,0 +1,283 @@ +#include "dynamic_cache.hpp" + +namespace infinilm::cache { + +// KVCacheLayer Implementation + +KVCacheLayer::KVCacheLayer() + : max_capacity(0), + initial_capacity(4096), + initial_batch_size(1), + growth_factor(2.0f), + initialized(false) {} + +void KVCacheLayer::ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, + size_t seq_len, infinicore::DataType dtype, + const infinicore::Device &device, const CacheConfig &cache_config) { + size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), size_t(0), [](size_t a, size_t b) { return std::max(a, b); }); + + // VALIDATION: Verify input parameters + if (num_kv_heads == 0 || head_dim == 0 || seq_len == 0) { + SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}", + num_kv_heads, head_dim, seq_len); + throw std::runtime_error("KV cache ensure_capacity: invalid parameters"); + } + + // Store config parameters on first initialization + if (!initialized) { + initial_capacity = cache_config.initial_capacity; + initial_batch_size = cache_config.initial_batch_size; + growth_factor = cache_config.growth_factor; + } + + // Lazy initialization + if (!initialized) { + // Use max of required capacity and initial capacity from config + max_capacity = std::max(required_capacity, initial_capacity); + + // Use max of current batch size and initial batch size from config + size_t alloc_batch_size = std::max(batch_size, initial_batch_size); + + k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim}, + dtype, device); + v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim}, + dtype, device); + cache_positions = std::vector(alloc_batch_size, 0); + initialized = true; + + spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})", + alloc_batch_size, max_capacity, initial_batch_size, initial_capacity); + + // VALIDATION: Verify cache was created correctly + if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) { + SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization"); + throw std::runtime_error("KV cache initialization: shape mismatch"); + } + } + // Grow cache if needed using growth factor from config + else if (required_capacity > max_capacity) { + if (!cache_config.allow_expand) { + SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config"); + throw std::runtime_error("KV cache expansion not allowed"); + } + // Calculate new capacity using growth factor + size_t new_capacity = static_cast( + std::max(static_cast(max_capacity) * growth_factor, + static_cast(required_capacity + max_capacity))); + + // Ensure we don't exceed max_position_embeddings if specified + if (cache_config.max_kv_cache_length != 0) { + new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length); + } + + // Ensure we grow by at least some minimum amount + size_t min_growth = 256; + if (new_capacity - max_capacity < min_growth) { + new_capacity = max_capacity + min_growth; + } + + size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]); + if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) { + throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache."); + } + if (new_batch_size > cache_positions.size()) { + cache_positions.resize(new_batch_size, 0); + } + + auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, + dtype, device); + auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, + dtype, device); + + spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})", + max_capacity, new_capacity, growth_factor); + + // Copy existing cache data + for (size_t b = 0; b < new_batch_size; ++b) { + size_t cache_position = cache_positions[b]; + if (cache_position > 0) { + auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}}); + auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}}); + k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice); + v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice); + } + } + + k_cache = k_new; + v_cache = v_new; + max_capacity = new_capacity; + + // VALIDATION: Verify cache was grown correctly + if (k_cache->shape()[2] != new_capacity) { + SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch"); + throw std::runtime_error("KV cache growth: capacity mismatch"); + } + } + + // VALIDATION: Final check that capacity is sufficient + if (required_capacity > max_capacity) { + SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth"); + throw std::runtime_error("KV cache ensure_capacity: capacity insufficient"); + } +} + +std::pair KVCacheLayer::update( + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new, + const CacheConfig &cache_config) { + if (k_new->ndim() != 4 || v_new->ndim() != 4) { + throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors"); + } + size_t batch_size = k_new->shape()[0]; + size_t num_kv_heads = k_new->shape()[1]; + size_t seq_len = k_new->shape()[2]; + size_t head_dim = k_new->shape()[3]; + + // Ensure capacity with cache config + ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len, + k_new->dtype(), k_new->device(), cache_config); + + // Copy new k/v into cache at current position + bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin()); + if (all_equal) { + auto cache_position = cache_positions[0]; + + auto k_dst = k_cache->narrow({{2, cache_position, seq_len}}); + auto v_dst = v_cache->narrow({{2, cache_position, seq_len}}); + k_dst->copy_from(k_new); + v_dst->copy_from(v_new); + + // Update position + cache_position += seq_len; + for (size_t b = 0; b < batch_size; ++b) { + cache_positions[b] = cache_position; + } + + // Return the total cache up to current position + auto k_total = k_cache->narrow({{2, 0, cache_position}}); + auto v_total = v_cache->narrow({{2, 0, cache_position}}); + + return std::make_pair(k_total, v_total); + } else { + throw std::runtime_error("KVCache update: cache positions must be equal among a batch."); + } +} + +// DynamicCache Implementation + +DynamicCache::DynamicCache(const CacheConfig &cache_config) + : cache_config_(cache_config), layers_(cache_config.num_layers) { + if (cache_config.num_layers == 0) { + throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig"); + } +} + +DynamicCache::DynamicCache(size_t num_layers, size_t max_position_embeddings) + : cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), + layers_(num_layers) { + if (num_layers == 0) { + throw std::runtime_error("DynamicCache: num_layers must be greater than 0"); + } +} + +std::pair DynamicCache::update( + size_t layer_idx, + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) { + if (layer_idx >= layers_.size()) { + SPDLOG_ERROR("DynamicCache::update: layer_idx {} out of range (num_layers: {})", + layer_idx, layers_.size()); + throw std::runtime_error("DynamicCache: layer_idx out of range"); + } + + // Update the cache for this layer with cache config + return layers_[layer_idx].update(k_new, v_new, cache_config_); +} + +std::pair DynamicCache::update( + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) { + return update(0, k_new, v_new); +} + +const CacheConfig &DynamicCache::get_config() const { + return cache_config_; +} + +void DynamicCache::update_config(const CacheConfig &new_config) { + // Check if we need to rebuild + bool need_rebuild = false; + + // Rebuild if number of layers changed + if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) { + need_rebuild = true; + layers_.resize(new_config.num_layers); + } + + // Rebuild if reset mode is RECREATE + if (new_config.reset_mode == CacheResetMode::RECREATE) { + need_rebuild = true; + } + + // Update configuration + cache_config_ = new_config; + + if (need_rebuild) { + // Clear all layers to force reinitialization on next use + for (auto &layer : layers_) { + layer.initialized = false; + layer.max_capacity = 0; + // Tensors will be recreated when ensure_capacity is called + } + spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use"); + } else { + spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}", + new_config.num_layers, new_config.initial_capacity, new_config.growth_factor); + } +} + +size_t DynamicCache::num_layers() const { + return layers_.size(); +} + +size_t DynamicCache::cache_position(size_t layer_idx) const { + if (layer_idx >= layers_.size()) { + throw std::runtime_error("DynamicCache: layer_idx out of range"); + } + if (layers_[layer_idx].cache_positions.empty()) { + return 0; + } + return layers_[layer_idx].cache_positions[0]; +} + +bool DynamicCache::is_initialized() const { + return !layers_.empty() && layers_[0].initialized; +} + +size_t DynamicCache::max_kv_cache_length() const { + return cache_config_.max_kv_cache_length; +} + +void DynamicCache::reset(size_t pos) { + for (auto &layer : layers_) { + std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos); + // Note: We don't reset initialized flag or clear the cache tensors + // to avoid reallocation. The cache will be overwritten on next update. + } +} + +KVCacheLayer &DynamicCache::layer(size_t layer_idx) { + if (layer_idx >= layers_.size()) { + throw std::runtime_error("DynamicCache: layer_idx out of range"); + } + return layers_[layer_idx]; +} + +const KVCacheLayer &DynamicCache::layer(size_t layer_idx) const { + if (layer_idx >= layers_.size()) { + throw std::runtime_error("DynamicCache: layer_idx out of range"); + } + return layers_[layer_idx]; +} + +} // namespace infinilm::cache diff --git a/csrc/cache/dynamic_cache/dynamic_cache.hpp b/csrc/cache/dynamic_cache/dynamic_cache.hpp new file mode 100644 index 00000000..7e7f67e7 --- /dev/null +++ b/csrc/cache/dynamic_cache/dynamic_cache.hpp @@ -0,0 +1,166 @@ +#pragma once + +#include "infinicore/context/context.hpp" +#include "infinicore/device.hpp" +#include "infinicore/tensor.hpp" + +#include "../cache_config.hpp" +#include "../cache_interface.hpp" + +#include +#include +#include +#include +#include + +#include + +namespace infinilm::cache { + +/** + * @brief Single layer's KV cache for incremental decoding + * + * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim] + * Similar to DynamicLayer in Python cache_utils.py + * + * This represents a single layer's cache within a model-level cache container. + */ +struct KVCacheLayer { + infinicore::Tensor k_cache; // [batch_size, n_kv_head, capacity, head_dim] + infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim] + std::vector cache_positions; // Current position in cache + size_t max_capacity; // Maximum capacity of cache + size_t initial_capacity; // Initial capacity from config + size_t initial_batch_size; // Initial batch size from config + float growth_factor; // Growth factor for dynamic resizing + bool initialized; // Whether cache has been initialized + + /** + * @brief Default constructor + */ + KVCacheLayer(); + + /** + * @brief Initialize or update cache capacity with config parameters + * @param batch_size Current batch size + * @param num_kv_heads Number of key-value heads + * @param head_dim Head dimension + * @param seq_len Sequence length of new tokens + * @param dtype Data type + * @param device Device + * @param cache_config Cache configuration parameters + */ + void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len, + infinicore::DataType dtype, const infinicore::Device &device, + const CacheConfig &cache_config); + + /** + * @brief Update cache with new key and value states + * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] + * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] + * @param cache_config Cache configuration for capacity management + * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] + */ + std::pair update( + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new, + const CacheConfig &cache_config); +}; + +/** + * @brief Model-level KV cache container (similar to DynamicCache in Python) + * + * Stores a list of KVCacheLayer objects, one per model layer. + * This aligns with Python backend's DynamicCache architecture. + */ +class DynamicCache : public CacheInterface { +public: + /** + * @brief Construct DynamicCache with cache configuration + * @param cache_config Cache configuration parameters + */ + explicit DynamicCache(const CacheConfig &cache_config); + + /** + * @brief Construct DynamicCache with specified number of layers + * + * @param num_layers Number of model layers (creates one cache layer per model layer) + * @param max_position_embeddings Maximum position embeddings (used for initial capacity) + */ + DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096); + + /** + * @brief Update cache with new key and value states for a specific layer + */ + std::pair update( + size_t layer_idx, + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) override; + + /** + * @brief Update cache with new key and value states (convenience method without layer_idx) + * This is used when the cache is accessed directly without layer information + * + * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] + * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] + * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] + * + * Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead. + */ + std::pair update( + const infinicore::Tensor &k_new, + const infinicore::Tensor &v_new) override; + + /** + * @brief Get cache configuration + */ + const CacheConfig &get_config() const override; + + /** + * @brief Update cache configuration (for dynamic reconfiguration) + */ + void update_config(const CacheConfig &new_config) override; + + /** + * @brief Get the number of layers in this cache + */ + size_t num_layers() const override; + + /** + * @brief Get cache position for a specific layer + */ + size_t cache_position(size_t layer_idx) const override; + + /** + * @brief Check if cache is initialized + */ + bool is_initialized() const override; + + /** + * @brief Get max position embeddings (used for initial capacity) + */ + size_t max_kv_cache_length() const; + + /** + * @brief Reset cache for all layers to a specific position + * This should be called when starting a new generation sequence or resetting to a specific position + * @param pos Position to reset to (defaults to 0) + */ + void reset(size_t pos = 0) override; + + /** + * @brief Access a specific layer's cache (for advanced usage) + */ + KVCacheLayer &layer(size_t layer_idx); + + /** + * @brief Access a specific layer's cache (const version) + */ + const KVCacheLayer &layer(size_t layer_idx) const; + +private: + CacheConfig cache_config_; + std::vector layers_; +}; + +} // namespace infinilm::cache diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp deleted file mode 100644 index 4e50297e..00000000 --- a/csrc/cache/kv_cache.hpp +++ /dev/null @@ -1,369 +0,0 @@ -#pragma once - -#include "infinicore/context/context.hpp" -#include "infinicore/device.hpp" -#include "infinicore/tensor.hpp" - -#include "cache_config.hpp" - -#include -#include -#include -#include -#include - -#include - -namespace infinilm::cache { - -/** - * @brief Single layer's KV cache for incremental decoding - * - * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim] - * Similar to DynamicLayer in Python cache_utils.py - * - * This represents a single layer's cache within a model-level cache container. - */ -struct KVCacheLayer { - infinicore::Tensor k_cache; // [batch_size, n_kv_head, capacity, head_dim] - infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim] - std::vector cache_positions; // Current position in cache - size_t max_capacity; // Maximum capacity of cache - size_t initial_capacity; // Initial capacity from config - size_t initial_batch_size; // Initial batch size from config - float growth_factor; // Growth factor for dynamic resizing - bool initialized; // Whether cache has been initialized - - KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1), - growth_factor(2.0f), initialized(false) {} - - /** - * @brief Initialize or update cache capacity with config parameters - * @param batch_size Current batch size - * @param num_kv_heads Number of key-value heads - * @param head_dim Head dimension - * @param seq_len Sequence length of new tokens - * @param dtype Data type - * @param device Device - * @param cache_config Cache configuration parameters - */ - void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len, - infinicore::DataType dtype, const infinicore::Device &device, - const CacheConfig &cache_config) { - size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); }); - - // VALIDATION: Verify input parameters - if (num_kv_heads == 0 || head_dim == 0 || seq_len == 0) { - SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}", - num_kv_heads, head_dim, seq_len); - throw std::runtime_error("KV cache ensure_capacity: invalid parameters"); - } - - // Store config parameters on first initialization - if (!initialized) { - initial_capacity = cache_config.initial_capacity; - initial_batch_size = cache_config.initial_batch_size; - growth_factor = cache_config.growth_factor; - } - - // Lazy initialization - if (!initialized) { - // Use max of required capacity and initial capacity from config - max_capacity = std::max(required_capacity, initial_capacity); - - // Use max of current batch size and initial batch size from config - size_t alloc_batch_size = std::max(batch_size, initial_batch_size); - - k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim}, - dtype, device); - v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim}, - dtype, device); - cache_positions = std::vector(alloc_batch_size, 0); - initialized = true; - - spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})", - alloc_batch_size, max_capacity, initial_batch_size, initial_capacity); - - // VALIDATION: Verify cache was created correctly - if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) { - SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization"); - throw std::runtime_error("KV cache initialization: shape mismatch"); - } - } - // Grow cache if needed using growth factor from config - else if (required_capacity > max_capacity) { - if (!cache_config.allow_expand) { - SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config"); - throw std::runtime_error("KV cache expansion not allowed"); - } - // Calculate new capacity using growth factor - size_t new_capacity = static_cast( - std::max(static_cast(max_capacity) * growth_factor, - static_cast(required_capacity + max_capacity))); - - // Ensure we don't exceed max_position_embeddings if specified - if (cache_config.max_kv_cache_length != 0) { - new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length); - } - - // Ensure we grow by at least some minimum amount - size_t min_growth = 256; - if (new_capacity - max_capacity < min_growth) { - new_capacity = max_capacity + min_growth; - } - - size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]); - if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) { - throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache."); - } - if (new_batch_size > cache_positions.size()) { - cache_positions.resize(new_batch_size, 0); - } - - auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, - dtype, device); - auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, - dtype, device); - - spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})", - max_capacity, new_capacity, growth_factor); - - // Copy existing cache data - for (size_t b = 0; b < new_batch_size; ++b) { - size_t cache_position = cache_positions[b]; - if (cache_position > 0) { - auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}}); - auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}}); - k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice); - v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice); - } - } - - k_cache = k_new; - v_cache = v_new; - max_capacity = new_capacity; - - // VALIDATION: Verify cache was grown correctly - if (k_cache->shape()[2] != new_capacity) { - SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch"); - throw std::runtime_error("KV cache growth: capacity mismatch"); - } - } - - // VALIDATION: Final check that capacity is sufficient - if (required_capacity > max_capacity) { - SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth"); - throw std::runtime_error("KV cache ensure_capacity: capacity insufficient"); - } - } - - /** - * @brief Update cache with new key and value states - * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] - * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] - * @param cache_config Cache configuration for capacity management - * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] - */ - std::pair update( - const infinicore::Tensor &k_new, - const infinicore::Tensor &v_new, - const CacheConfig &cache_config) { - if (k_new->ndim() != 4 || v_new->ndim() != 4) { - throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors"); - } - size_t batch_size = k_new->shape()[0]; - size_t num_kv_heads = k_new->shape()[1]; - size_t seq_len = k_new->shape()[2]; - size_t head_dim = k_new->shape()[3]; - - // Ensure capacity with cache config - ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len, - k_new->dtype(), k_new->device(), cache_config); - - // Copy new k/v into cache at current position - bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin()); - if (all_equal) { - auto cache_position = cache_positions[0]; - - auto k_dst = k_cache->narrow({{2, cache_position, seq_len}}); - auto v_dst = v_cache->narrow({{2, cache_position, seq_len}}); - k_dst->copy_from(k_new); - v_dst->copy_from(v_new); - - // Update position - cache_position += seq_len; - for (size_t b = 0; b < batch_size; ++b) { - cache_positions[b] = cache_position; - } - - // Return the total cache up to current position - auto k_total = k_cache->narrow({{2, 0, cache_position}}); - auto v_total = v_cache->narrow({{2, 0, cache_position}}); - - return std::make_pair(k_total, v_total); - } else { - throw std::runtime_error("KVCache update: cache positions must be equal among a batch."); - } - } -}; - -/** - * @brief Model-level KV cache container (similar to DynamicCache in Python) - * - * Stores a list of KVCacheLayer objects, one per model layer. - * This aligns with Python backend's DynamicCache architecture. - */ -class DynamicCache { -public: - /** - * @brief Construct DynamicCache with cache configuration - * @param cache_config Cache configuration parameters - */ - DynamicCache(const CacheConfig &cache_config) - : cache_config_(cache_config), layers_(cache_config.num_layers) { - if (cache_config.num_layers == -1) { - throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig"); - } - } - - /** - * @brief Construct DynamicCache with specified number of layers - * - * @param num_layers Number of model layers (creates one cache layer per model layer) - * @param max_position_embeddings Maximum position embeddings (used for initial capacity) - */ - DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096) - : cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {} - - /** - * @brief Update cache with new key and value states for a specific layer - */ - std::pair update( - size_t layer_idx, - const infinicore::Tensor &k_new, - const infinicore::Tensor &v_new) { - if (layer_idx >= layers_.size()) { - SPDLOG_ERROR("DynamicCache::update: layer_idx {} out of range (num_layers: {})", - layer_idx, layers_.size()); - throw std::runtime_error("DynamicCache: layer_idx out of range"); - } - - // Update the cache for this layer with cache config - return layers_[layer_idx].update(k_new, v_new, cache_config_); - } - - /** - * @brief Update cache with new key and value states (convenience method without layer_idx) - * This is used when the cache is accessed directly without layer information - * - * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] - * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] - * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] - * - * Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead. - */ - std::pair update( - const infinicore::Tensor &k_new, - const infinicore::Tensor &v_new) { - return update(0, k_new, v_new); - } - - /** - * @brief Get cache configuration - */ - const CacheConfig &get_config() const { return cache_config_; } - - /** - * @brief Update cache configuration (for dynamic reconfiguration) - */ - void update_config(const CacheConfig &new_config) { - // Check if we need to rebuild - bool need_rebuild = false; - - // Rebuild if number of layers changed - if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) { - need_rebuild = true; - layers_.resize(new_config.num_layers); - } - - // Rebuild if reset mode is RECREATE - if (new_config.reset_mode == CacheResetMode::RECREATE) { - need_rebuild = true; - } - - // Update configuration - cache_config_ = new_config; - - if (need_rebuild) { - // Clear all layers to force reinitialization on next use - for (auto &layer : layers_) { - layer.initialized = false; - layer.max_capacity = 0; - // Tensors will be recreated when ensure_capacity is called - } - spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use"); - } else { - spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}", - new_config.num_layers, new_config.initial_capacity, new_config.growth_factor); - } - } - - /** - * @brief Get the number of layers in this cache - */ - size_t num_layers() const { return layers_.size(); } - - /** - * @brief Get cache position for a specific layer - */ - size_t cache_position(size_t layer_idx) const { - if (layer_idx >= layers_.size()) { - throw std::runtime_error("DynamicCache: layer_idx out of range"); - } - if (layers_[layer_idx].cache_positions.empty()) { - return 0; - } - return layers_[layer_idx].cache_positions[0]; // All batch items should have same position - } - - /** - * @brief Get max position embeddings (used for initial capacity) - */ - size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; } - - /** - * @brief Reset cache for all layers to a specific position - * This should be called when starting a new generation sequence or resetting to a specific position - * @param pos Position to reset to (defaults to 0) - */ - void reset(size_t pos = 0) { - for (auto &layer : layers_) { - std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos); - // Note: We don't reset initialized flag or clear the cache tensors - // to avoid reallocation. The cache will be overwritten on next update. - } - } - - /** - * @brief Access a specific layer's cache (for advanced usage) - */ - KVCacheLayer &layer(size_t layer_idx) { - if (layer_idx >= layers_.size()) { - throw std::runtime_error("DynamicCache: layer_idx out of range"); - } - return layers_[layer_idx]; - } - - const KVCacheLayer &layer(size_t layer_idx) const { - if (layer_idx >= layers_.size()) { - throw std::runtime_error("DynamicCache: layer_idx out of range"); - } - return layers_[layer_idx]; - } - -private: - CacheConfig cache_config_; - std::vector layers_; -}; - -} // namespace infinilm::cache diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 8740f7a3..e8faa392 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -177,7 +177,11 @@ void RankWorker::thread_loop() { // Initialize device & model outside of holding the main mutex to avoid blocking callers. infinicore::context::setDevice(rank_info_.device); - cache_ptr_ = std::make_shared(pending_cache_config_); + cache_ptr_ = cache::CacheInterface::create(pending_cache_config_); + spdlog::info("[{}] Created {} cache with {} layers", + info(), + cache_ptr_->get_config().type == cache::CacheType::DYNAMIC ? "Dynamic" : "Paged", + cache_ptr_->num_layers()); // Create model using factory (may be expensive) model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_); @@ -269,15 +273,8 @@ void RankWorker::thread_loop() { } } else if (local_cmd == Command::RESET_CACHE) { try { - // Option 1: Use model's reset_cache if it handles cache model_->reset_cache(local_reset_pos); - // Option 2: Reset cache directly if we have access - // if (cache_ptr_ != nullptr) { - // auto* dynamic_cache = static_cast(cache_ptr_); - // dynamic_cache->reset(local_reset_pos); - // } - { std::lock_guard lk(mutex_); job_done_ = true; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 63e4cef9..265d9c0f 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -63,7 +63,7 @@ class RankWorker { std::any model_config_; distributed::RankInfo rank_info_; std::shared_ptr model_; - std::shared_ptr cache_ptr_; + std::shared_ptr cache_ptr_; // Command for the pending job (protected by mutex_) Command job_cmd_; diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index 8d951f3f..7af7dcfa 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -97,7 +97,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] - infinilm::cache::DynamicCache *external_cache = static_cast(kv_cache); + infinilm::cache::CacheInterface *external_cache = static_cast(kv_cache); infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] if (external_cache != nullptr) { diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index bdc98d23..b6ef0ace 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -1,6 +1,6 @@ #pragma once -#include "../../cache/kv_cache.hpp" +#include "../../cache/cache.hpp" #include "../../engine/distributed/distributed.hpp" #include "../../layers/fused_linear.hpp" #include "llama_config.hpp" diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index f02a9f7f..a5aad352 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -1,6 +1,6 @@ #pragma once -#include "../../cache/kv_cache.hpp" +#include "../../cache/cache.hpp" #include "llama_config.hpp" #include "llama_decoder_layer.hpp" @@ -79,7 +79,7 @@ class LlamaModel : public infinicore::nn::Module { * @brief Set external cache for the model * @param cache Pointer to external cache (managed by CacheManager) */ - void set_external_cache(std::shared_ptr cache) { + void set_external_cache(std::shared_ptr cache) { external_cache_ = cache.get(); } @@ -102,7 +102,7 @@ class LlamaModel : public infinicore::nn::Module { // Mutable because it's not part of the model's learned parameters, // but needs to persist across forward calls for incremental decoding mutable std::unique_ptr internal_cache_; - cache::DynamicCache *external_cache_ = nullptr; + cache::CacheInterface *external_cache_ = nullptr; }; } // namespace infinilm::models::llama diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index f65315ac..1615b159 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -5,7 +5,7 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( const std::any &config, engine::distributed::RankInfo rank_info, - std::shared_ptr cache_ptr) { + std::shared_ptr cache_ptr) { if (config.type() == typeid(models::llama::LlamaConfig)) { const auto &llama_config = std::any_cast(config); diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index 029c33b5..e6fb1a48 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -7,6 +7,6 @@ namespace infinilm { class InfinilmModelFactory { public: - static std::shared_ptr createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr cache_ptr = nullptr); + static std::shared_ptr createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr cache_ptr = nullptr); }; } // namespace infinilm diff --git a/csrc/pybind11/models/llama.hpp b/csrc/pybind11/models/llama.hpp index 3ce39ebe..b8c0e3f1 100644 --- a/csrc/pybind11/models/llama.hpp +++ b/csrc/pybind11/models/llama.hpp @@ -1,6 +1,6 @@ #pragma once -#include "../../cache/kv_cache.hpp" +#include "../../cache/cache.hpp" #include "../../models/debug_utils/hooks.hpp" #include "../../models/llama/llama.hpp" #include "../../models/llama/llama_attention.hpp" @@ -65,7 +65,8 @@ inline void bind_llama(py::module &m) { .def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp) .def_readwrite("name_or_path", &LlamaConfig::name_or_path) .def_readwrite("pad_token_id", &LlamaConfig::pad_token_id) - .def_property("bos_token_id", [](const LlamaConfig &self) { + .def_property( + "bos_token_id", [](const LlamaConfig &self) { // Always return as list to match Python config format return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) { // Accept both single int and list @@ -76,7 +77,8 @@ inline void bind_llama(py::module &m) { } else { throw py::type_error("bos_token_id must be int or list of ints"); } }) - .def_property("eos_token_id", [](const LlamaConfig &self) { + .def_property( + "eos_token_id", [](const LlamaConfig &self) { // Always return as list to match Python config format return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) { // Accept both single int and list @@ -133,7 +135,8 @@ inline void bind_llama(py::module &m) { } return result; }) - .def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) { + .def( + "get_parameter", [](const LlamaForCausalLM &model, const std::string &name) { // Get actual tensor parameter by name auto state_dict = model.state_dict(); auto it = state_dict.find(name); @@ -143,7 +146,8 @@ inline void bind_llama(py::module &m) { return tensor; } throw std::runtime_error("Parameter '" + name + "' not found in model"); }, py::arg("name")) - .def("load_state_dict", [](LlamaForCausalLM &model, py::dict state_dict) { + .def( + "load_state_dict", [](LlamaForCausalLM &model, py::dict state_dict) { // Convert Python dict to C++ state_dict std::unordered_map cpp_state_dict; for (auto item : state_dict) { @@ -160,10 +164,12 @@ inline void bind_llama(py::module &m) { } model.load_state_dict(cpp_state_dict); }, py::arg("state_dict")) .def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal) - .def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) { + .def( + "reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) { // Reset the internal cache to prevent state from persisting between generations model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)") - .def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) { + .def( + "forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) { // Helper to extract C++ tensor from Python InfiniCore tensor auto get_tensor = [](py::object obj) -> infinicore::Tensor { // If it's already a Python InfiniCore tensor wrapper, extract underlying From 9da90b5a224183f33877afcfe6fbe6ac4f6ae895 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 12 Dec 2025 16:20:59 +0800 Subject: [PATCH 2/2] issue/125 - renamed cache interface to cache --- csrc/cache/cache_factory.cpp | 2 +- csrc/cache/cache_interface.hpp | 6 +++--- csrc/cache/dynamic_cache/dynamic_cache.hpp | 2 +- csrc/engine/rank_worker.cpp | 2 +- csrc/engine/rank_worker.hpp | 2 +- csrc/models/llama/llama_attention.cpp | 2 +- csrc/models/llama/llama_model.hpp | 4 ++-- csrc/models/model_factory.cpp | 2 +- csrc/models/model_factory.hpp | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/csrc/cache/cache_factory.cpp b/csrc/cache/cache_factory.cpp index c4d6fd91..57486cea 100644 --- a/csrc/cache/cache_factory.cpp +++ b/csrc/cache/cache_factory.cpp @@ -4,7 +4,7 @@ namespace infinilm::cache { -std::shared_ptr CacheInterface::create(const CacheConfig &config) { +std::shared_ptr Cache::create(const CacheConfig &config) { switch (config.type) { case CacheType::DYNAMIC: return std::make_shared(config); diff --git a/csrc/cache/cache_interface.hpp b/csrc/cache/cache_interface.hpp index fa21c4a7..b90d9257 100644 --- a/csrc/cache/cache_interface.hpp +++ b/csrc/cache/cache_interface.hpp @@ -11,9 +11,9 @@ namespace infinilm::cache { * @brief Abstract interface for KV cache implementations * This allows different cache types (Dynamic, Paged, etc.) to be used interchangeably */ -class CacheInterface { +class Cache { public: - virtual ~CacheInterface() = default; + virtual ~Cache() = default; /** * @brief Update cache with new key and value states @@ -72,7 +72,7 @@ class CacheInterface { /** * @brief Factory method to create cache based on configuration */ - static std::shared_ptr create(const CacheConfig &config); + static std::shared_ptr create(const CacheConfig &config); }; } // namespace infinilm::cache diff --git a/csrc/cache/dynamic_cache/dynamic_cache.hpp b/csrc/cache/dynamic_cache/dynamic_cache.hpp index 7e7f67e7..93d5af69 100644 --- a/csrc/cache/dynamic_cache/dynamic_cache.hpp +++ b/csrc/cache/dynamic_cache/dynamic_cache.hpp @@ -73,7 +73,7 @@ struct KVCacheLayer { * Stores a list of KVCacheLayer objects, one per model layer. * This aligns with Python backend's DynamicCache architecture. */ -class DynamicCache : public CacheInterface { +class DynamicCache : public Cache { public: /** * @brief Construct DynamicCache with cache configuration diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index e8faa392..0c23ec50 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -177,7 +177,7 @@ void RankWorker::thread_loop() { // Initialize device & model outside of holding the main mutex to avoid blocking callers. infinicore::context::setDevice(rank_info_.device); - cache_ptr_ = cache::CacheInterface::create(pending_cache_config_); + cache_ptr_ = cache::Cache::create(pending_cache_config_); spdlog::info("[{}] Created {} cache with {} layers", info(), cache_ptr_->get_config().type == cache::CacheType::DYNAMIC ? "Dynamic" : "Paged", diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 265d9c0f..92cecb75 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -63,7 +63,7 @@ class RankWorker { std::any model_config_; distributed::RankInfo rank_info_; std::shared_ptr model_; - std::shared_ptr cache_ptr_; + std::shared_ptr cache_ptr_; // Command for the pending job (protected by mutex_) Command job_cmd_; diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index 7af7dcfa..fa03a1ef 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -97,7 +97,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] - infinilm::cache::CacheInterface *external_cache = static_cast(kv_cache); + infinilm::cache::Cache *external_cache = static_cast(kv_cache); infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] if (external_cache != nullptr) { diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index a5aad352..a249c967 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -79,7 +79,7 @@ class LlamaModel : public infinicore::nn::Module { * @brief Set external cache for the model * @param cache Pointer to external cache (managed by CacheManager) */ - void set_external_cache(std::shared_ptr cache) { + void set_external_cache(std::shared_ptr cache) { external_cache_ = cache.get(); } @@ -102,7 +102,7 @@ class LlamaModel : public infinicore::nn::Module { // Mutable because it's not part of the model's learned parameters, // but needs to persist across forward calls for incremental decoding mutable std::unique_ptr internal_cache_; - cache::CacheInterface *external_cache_ = nullptr; + cache::Cache *external_cache_ = nullptr; }; } // namespace infinilm::models::llama diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 1615b159..32e1ee10 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -5,7 +5,7 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( const std::any &config, engine::distributed::RankInfo rank_info, - std::shared_ptr cache_ptr) { + std::shared_ptr cache_ptr) { if (config.type() == typeid(models::llama::LlamaConfig)) { const auto &llama_config = std::any_cast(config); diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index e6fb1a48..f909163b 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -7,6 +7,6 @@ namespace infinilm { class InfinilmModelFactory { public: - static std::shared_ptr createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr cache_ptr = nullptr); + static std::shared_ptr createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr cache_ptr = nullptr); }; } // namespace infinilm