diff --git a/CMakeLists.txt b/CMakeLists.txt index 263118e..575e7bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,13 +24,6 @@ option(DEBUG "Build with debugging information" OFF) # ==== External Dependencies ==== -# OpenGL -find_package(OpenGL REQUIRED) -if(NOT OpenGL_FOUND) - message(FATAL_ERROR "OpenGL not found") -endif() -message(STATUS "Found OpenGL: ${OpenGL_INCLUDE_DIR}") - # OpenSSL find_package(OpenSSL REQUIRED) if(NOT OpenSSL_FOUND) @@ -112,7 +105,6 @@ target_compile_definitions(kolosal_lib PUBLIC target_include_directories(kolosal_lib PUBLIC ${IMGUI_DIR} ${IMGUI_DIR}/backends - ${OpenGL_INCLUDE_DIR} ${EXTERNAL_DIR}/glad/include ${EXTERNAL_DIR}/icons ${EXTERNAL_DIR}/nlohmann @@ -131,7 +123,6 @@ target_include_directories(kolosal_lib PUBLIC # Platform-specific library dependencies if(WIN32) target_link_libraries(kolosal_lib PUBLIC - ${OpenGL_LIBRARIES} glad nfd Dwmapi @@ -153,12 +144,11 @@ if(WIN32) ) else() target_link_libraries(kolosal_lib PUBLIC - ${OpenGL_LIBRARIES} glad nfd OpenSSL::SSL ) - + target_compile_definitions(kolosal_lib PUBLIC IMGUI_IMPL_OPENGL_LOADER_GLAD ${FONT_DEFINITIONS} diff --git a/external/genta-personal/bin/InferenceEngineLib.dll b/external/genta-personal/bin/InferenceEngineLib.dll index 0ae0518..af45d73 100644 Binary files a/external/genta-personal/bin/InferenceEngineLib.dll and b/external/genta-personal/bin/InferenceEngineLib.dll differ diff --git a/external/genta-personal/bin/InferenceEngineLibVulkan.dll b/external/genta-personal/bin/InferenceEngineLibVulkan.dll index d8e404b..105566d 100644 Binary files a/external/genta-personal/bin/InferenceEngineLibVulkan.dll and b/external/genta-personal/bin/InferenceEngineLibVulkan.dll differ diff --git a/external/genta-personal/include/types.h b/external/genta-personal/include/types.h index 6771f0a..fb187d7 100644 --- a/external/genta-personal/include/types.h +++ b/external/genta-personal/include/types.h @@ -69,7 +69,7 @@ struct LoadingParameters bool warmup = false; int 
n_parallel = 1; int n_gpu_layers = 100; - int n_batch = 4096; + int n_batch = 256; }; #endif // TYPES_H \ No newline at end of file diff --git a/include/config.hpp b/include/config.hpp index 9686f1b..d29ea58 100644 --- a/include/config.hpp +++ b/include/config.hpp @@ -2,6 +2,7 @@ #include +#define APP_VERSION "0.1.8" // TODO: Need to refactor this to use json file that is modifiable by the user in realtime // Set up a system to save and load the settings from a json file @@ -132,4 +133,5 @@ namespace Config constexpr float INPUT_HEIGHT = 100.0F; constexpr float CHAT_WINDOW_CONTENT_WIDTH = 750.0F; constexpr float TITLE_BAR_HEIGHT = 50.0F; + constexpr float FOOTER_HEIGHT = 22.0F; } // namespace Config \ No newline at end of file diff --git a/include/model/gguf_reader.hpp b/include/model/gguf_reader.hpp new file mode 100644 index 0000000..c2267ed --- /dev/null +++ b/include/model/gguf_reader.hpp @@ -0,0 +1,524 @@ +#ifndef GGUF_READER_H +#define GGUF_READER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Structure to hold the extracted model parameters +struct GGUFModelParams { + uint64_t hidden_size = 0; // Mapped from embedding_length + uint32_t attention_heads = 0; // Mapped from attention.head_count + uint32_t hidden_layers = 0; // Mapped from block_count + uint32_t kv_heads = 0; // Mapped from attention.head_count_kv or head_count +}; + +// Abstract base class for data sources +class DataSource { +public: + virtual ~DataSource() = default; + virtual bool read(char* buffer, size_t size) = 0; + virtual bool seek(size_t position) = 0; + virtual bool eof() const = 0; + virtual size_t tell() = 0; // Removed the const qualifier since tellg() is non-const. 
+}; + +// File-based data source +class FileDataSource : public DataSource { +public: + FileDataSource(const std::string& filename) { + file.open(filename, std::ios::binary); + if (!file) + throw std::runtime_error("Failed to open file: " + filename); + } + + ~FileDataSource() override { + if (file.is_open()) + file.close(); + } + + bool read(char* buffer, size_t size) override { + file.read(buffer, size); + return file.good() || (file.eof() && file.gcount() > 0); + } + + bool seek(size_t position) override { + file.seekg(position); + return file.good(); + } + + bool eof() const override { + return file.eof(); + } + + size_t tell() override { + return file.tellg(); + } + +private: + std::ifstream file; +}; + +// CURL callback data structure +struct CurlBuffer { + char* buffer; + size_t size; + size_t pos; + bool* abort_download; +}; + +// CURL-based URL data source +class UrlDataSource : public DataSource { +public: + UrlDataSource(const std::string& url) : url(url), currentPos(0), abortDownload(false) { + curl = curl_easy_init(); + if (!curl) + throw std::runtime_error("Failed to initialize curl"); + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &writeData); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, ProgressCallback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &abortDownload); + + downloadedData.resize(BUFFER_SIZE); + bufferSize = 0; + bufferPos = 0; + } + + ~UrlDataSource() override { + if (curl) + curl_easy_cleanup(curl); + } + + bool read(char* buffer, size_t size) override { + while (bufferPos + size > bufferSize) { + if (bufferPos >= bufferSize) { + bufferSize = 0; + bufferPos = 0; + } + + if (bufferPos > 0 && bufferSize > bufferPos) { + memmove(&downloadedData[0], &downloadedData[bufferPos], bufferSize - bufferPos); + bufferSize -= 
bufferPos; + bufferPos = 0; + } + + writeData.buffer = &downloadedData[bufferSize]; + writeData.size = downloadedData.size() - bufferSize; + writeData.pos = 0; + writeData.abort_download = &abortDownload; + + std::string range = std::to_string(currentPos + bufferSize) + "-" + + std::to_string(currentPos + bufferSize + CHUNK_SIZE - 1); + curl_easy_setopt(curl, CURLOPT_RANGE, range.c_str()); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK && res != CURLE_WRITE_ERROR) { + std::cerr << "curl_easy_perform() failed: " << curl_easy_strerror(res) << std::endl; + return false; + } + + if (writeData.pos == 0) { + _eof = true; + return false; + } + bufferSize += writeData.pos; + } + + // Specify the template type explicitly to avoid macro issues with std::min. + size_t copySize = std::min(size, bufferSize - bufferPos); + memcpy(buffer, &downloadedData[bufferPos], copySize); + bufferPos += copySize; + currentPos += copySize; + + return copySize == size; + } + + bool seek(size_t position) override { + if (position >= currentPos - bufferPos && position < currentPos + (bufferSize - bufferPos)) { + bufferPos = position - (currentPos - bufferPos); + currentPos = position; + return true; + } + bufferSize = 0; + bufferPos = 0; + currentPos = position; + _eof = false; + return true; + } + + bool eof() const override { + return _eof; + } + + size_t tell() override { + return currentPos; + } + + void setAbortFlag() { + abortDownload = true; + } + +private: + static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { + CurlBuffer* data = static_cast(userdata); + if (*(data->abort_download)) + return 0; + size_t bytes = size * nmemb; + size_t available = data->size - data->pos; + if (bytes > available) + bytes = available; + memcpy(data->buffer + data->pos, ptr, bytes); + data->pos += bytes; + return bytes; + } + + static int ProgressCallback(void* clientp, curl_off_t, curl_off_t dlnow, curl_off_t, curl_off_t) { + bool* abort_flag = 
static_cast(clientp); + return (*abort_flag) ? 1 : 0; + } + + std::string url; + CURL* curl; + CurlBuffer writeData; + std::vector downloadedData; + size_t bufferSize; + size_t bufferPos; + size_t currentPos; + bool abortDownload; + bool _eof = false; + + static constexpr size_t BUFFER_SIZE = 1024 * 1024; // 1MB buffer + static constexpr size_t CHUNK_SIZE = 256 * 1024; // 256KB chunk size +}; + +class GGUFMetadataReader { +public: + // GGUF metadata types + enum class GGUFType : uint32_t { + UINT8 = 0, + INT8 = 1, + UINT16 = 2, + INT16 = 3, + UINT32 = 4, + INT32 = 5, + FLOAT32 = 6, + BOOL = 7, + STRING = 8, + ARRAY = 9, + UINT64 = 10, + INT64 = 11, + FLOAT64 = 12, + MAX_TYPE = 13 + }; + + GGUFMetadataReader() { + curl_global_init(CURL_GLOBAL_ALL); + } + + ~GGUFMetadataReader() { + curl_global_cleanup(); + } + + bool isUrl(const std::string& path) { + return path.substr(0, 7) == "http://" || path.substr(0, 8) == "https://"; + } + + std::optional readModelParams(const std::string& path, bool verbose = false) { + std::unique_ptr source; + try { + if (isUrl(path)) { + if (verbose) + std::cout << "Reading from URL: " << path << std::endl; + source = std::make_unique(path); + } + else { + if (verbose) + std::cout << "Reading from file: " << path << std::endl; + source = std::make_unique(path); + } + + uint32_t magic; + if (!source->read(reinterpret_cast(&magic), sizeof(magic))) + throw std::runtime_error("Failed to read magic number"); + if (magic != 0x46554747) { + std::cerr << "Invalid GGUF file format. 
Magic number: " + << std::hex << magic << std::dec << std::endl; + return std::nullopt; + } + + uint32_t version; + if (!source->read(reinterpret_cast(&version), sizeof(version))) + throw std::runtime_error("Failed to read version"); + if (version > 3) { + std::cerr << "Unsupported GGUF version: " << version << std::endl; + return std::nullopt; + } + if (verbose) + std::cout << "GGUF version: " << version << std::endl; + + uint64_t tensorCount = 0; + if (version >= 1) { + if (!source->read(reinterpret_cast(&tensorCount), sizeof(tensorCount))) + throw std::runtime_error("Failed to read tensor count"); + if (verbose) + std::cout << "Tensor count: " << tensorCount << std::endl; + } + + uint64_t metadataCount; + if (!source->read(reinterpret_cast(&metadataCount), sizeof(metadataCount))) + throw std::runtime_error("Failed to read metadata count"); + if (verbose) + std::cout << "Metadata count: " << metadataCount << std::endl; + + const std::vector suffixes = { + ".attention.head_count", + ".attention.head_count_kv", + ".block_count", + ".embedding_length" + }; + + GGUFModelParams params; + std::unordered_map foundParams; + std::vector allKeys; + + for (uint64_t i = 0; i < metadataCount && !source->eof(); ++i) { + std::string key; + try { + key = readString(source.get()); + allKeys.push_back(key); + } + catch (const std::exception& e) { + throw std::runtime_error(std::string("Failed to read key: ") + e.what()); + } + + uint32_t typeVal; + if (!source->read(reinterpret_cast(&typeVal), sizeof(typeVal))) + throw std::runtime_error("Failed to read metadata type for key: " + key); + if (typeVal >= static_cast(GGUFType::MAX_TYPE)) + throw std::runtime_error("Invalid metadata type: " + std::to_string(typeVal) + " for key: " + key); + GGUFType type = static_cast(typeVal); + + if (verbose) + std::cout << "Key: " << key << ", Type: " << static_cast(type) << std::endl; + + bool keyMatched = false; + std::string matchedSuffix; + for (const auto& suffix : suffixes) { + if 
(endsWith(key, suffix)) { + keyMatched = true; + matchedSuffix = suffix; + break; + } + } + + if (keyMatched) { + if (matchedSuffix == ".attention.head_count" && (type == GGUFType::UINT32 || type == GGUFType::INT32)) { + uint32_t value; + if (!source->read(reinterpret_cast(&value), sizeof(value))) + throw std::runtime_error("Failed to read attention_heads value"); + params.attention_heads = value; + foundParams["attention_heads"] = true; + if (verbose) + std::cout << " Found attention_heads: " << value << " (from key: " << key << ")" << std::endl; + } + else if (matchedSuffix == ".attention.head_count_kv" && (type == GGUFType::UINT32 || type == GGUFType::INT32)) { + uint32_t value; + if (!source->read(reinterpret_cast(&value), sizeof(value))) + throw std::runtime_error("Failed to read kv_heads value"); + params.kv_heads = value; + foundParams["kv_heads"] = true; + if (verbose) + std::cout << " Found kv_heads: " << value << " (from key: " << key << ")" << std::endl; + } + else if (matchedSuffix == ".block_count" && (type == GGUFType::UINT32 || type == GGUFType::INT32)) { + uint32_t value; + if (!source->read(reinterpret_cast(&value), sizeof(value))) + throw std::runtime_error("Failed to read hidden_layers value"); + params.hidden_layers = value; + foundParams["hidden_layers"] = true; + if (verbose) + std::cout << " Found hidden_layers: " << value << " (from key: " << key << ")" << std::endl; + } + else if (matchedSuffix == ".embedding_length") { + if (type == GGUFType::UINT64 || type == GGUFType::INT64) { + uint64_t value; + if (!source->read(reinterpret_cast(&value), sizeof(value))) + throw std::runtime_error("Failed to read hidden_size value (64-bit)"); + params.hidden_size = value; + foundParams["hidden_size"] = true; + if (verbose) + std::cout << " Found hidden_size: " << value << " (from key: " << key << ")" << std::endl; + } + else if (type == GGUFType::UINT32 || type == GGUFType::INT32) { + uint32_t value; + if (!source->read(reinterpret_cast(&value), 
sizeof(value))) + throw std::runtime_error("Failed to read hidden_size value (32-bit)"); + params.hidden_size = value; + foundParams["hidden_size"] = true; + if (verbose) + std::cout << " Found hidden_size: " << value << " (from key: " << key << ")" << std::endl; + } + else { + skipValue(source.get(), type); + } + } + else { + skipValue(source.get(), type); + } + } + else { + skipValue(source.get(), type); + } + + if (foundParams["attention_heads"] && + foundParams["hidden_layers"] && + foundParams["hidden_size"] && + (foundParams["kv_heads"] || foundParams["attention_heads"])) { + if (isUrl(path)) { + auto urlSource = dynamic_cast(source.get()); + if (urlSource) { + urlSource->setAbortFlag(); + if (verbose) + std::cout << "All required metadata found, aborting download" << std::endl; + } + } + break; + } + } + + if (!foundParams["kv_heads"] && foundParams["attention_heads"]) { + params.kv_heads = params.attention_heads; + foundParams["kv_heads"] = true; + if (verbose) + std::cout << " Using attention_heads as kv_heads: " << params.kv_heads << std::endl; + } + + bool allFound = foundParams["attention_heads"] && + foundParams["hidden_layers"] && + foundParams["hidden_size"]; + + if (!allFound) { + std::cerr << "Failed to find all required model parameters:" << std::endl; + if (!foundParams["attention_heads"]) std::cerr << " Missing: attention_heads (suffix: .attention.head_count)" << std::endl; + if (!foundParams["hidden_layers"]) std::cerr << " Missing: hidden_layers (suffix: .block_count)" << std::endl; + if (!foundParams["hidden_size"]) std::cerr << " Missing: hidden_size (suffix: .embedding_length)" << std::endl; + if (verbose) { + std::cerr << "All keys found:" << std::endl; + for (const auto& key : allKeys) + std::cerr << " " << key << std::endl; + } + return std::nullopt; + } + + return params; + } + catch (const std::exception& e) { + std::cerr << "Error reading GGUF file/URL: " << e.what() << std::endl; + return std::nullopt; + } + } + +private: + bool 
endsWith(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; + } + + std::string readString(DataSource* source) { + uint64_t length; + if (!source->read(reinterpret_cast(&length), sizeof(length))) + throw std::runtime_error("Failed to read string length"); + if (length > 1024 * 1024) + throw std::runtime_error("String too long: " + std::to_string(length)); + std::string str(length, '\0'); + if (length > 0) + if (!source->read(&str[0], length)) + throw std::runtime_error("Failed to read string data"); + return str; + } + + void skipArray(DataSource* source, GGUFType elemType) { + uint64_t count; + if (!source->read(reinterpret_cast(&count), sizeof(count))) + throw std::runtime_error("Failed to read array count"); + if (count > 1000000) + throw std::runtime_error("Array count too large: " + std::to_string(count)); + for (uint64_t i = 0; i < count; ++i) + skipValue(source, elemType); + } + + void skipValue(DataSource* source, GGUFType type) { + switch (type) { + case GGUFType::UINT8: + source->seek(source->tell() + sizeof(uint8_t)); + break; + case GGUFType::INT8: + source->seek(source->tell() + sizeof(int8_t)); + break; + case GGUFType::UINT16: + source->seek(source->tell() + sizeof(uint16_t)); + break; + case GGUFType::INT16: + source->seek(source->tell() + sizeof(int16_t)); + break; + case GGUFType::UINT32: + source->seek(source->tell() + sizeof(uint32_t)); + break; + case GGUFType::INT32: + source->seek(source->tell() + sizeof(int32_t)); + break; + case GGUFType::FLOAT32: + source->seek(source->tell() + sizeof(float)); + break; + case GGUFType::BOOL: + source->seek(source->tell() + sizeof(uint8_t)); + break; + case GGUFType::STRING: { + uint64_t length; + if (!source->read(reinterpret_cast(&length), sizeof(length))) + throw std::runtime_error("Failed to read string length for skipping"); + if (length > 1024 * 1024) + throw std::runtime_error("String too 
long: " + std::to_string(length)); + source->seek(source->tell() + length); + break; + } + case GGUFType::ARRAY: { + uint32_t elemTypeVal; + if (!source->read(reinterpret_cast(&elemTypeVal), sizeof(elemTypeVal))) + throw std::runtime_error("Failed to read array element type"); + if (elemTypeVal >= static_cast(GGUFType::MAX_TYPE)) + throw std::runtime_error("Invalid array element type: " + std::to_string(elemTypeVal)); + GGUFType elemType = static_cast(elemTypeVal); + skipArray(source, elemType); + break; + } + case GGUFType::UINT64: + source->seek(source->tell() + sizeof(uint64_t)); + break; + case GGUFType::INT64: + source->seek(source->tell() + sizeof(int64_t)); + break; + case GGUFType::FLOAT64: + source->seek(source->tell() + sizeof(double)); + break; + default: + throw std::runtime_error("Unknown GGUF type: " + std::to_string(static_cast(type))); + } + } +}; + +#endif // GGUF_READER_H \ No newline at end of file diff --git a/include/model/model.hpp b/include/model/model.hpp index b8d7fd9..e7ad4f0 100644 --- a/include/model/model.hpp +++ b/include/model/model.hpp @@ -10,7 +10,6 @@ using json = nlohmann::json; namespace Model { - // ModelVariant structure remains mostly the same struct ModelVariant { std::string type; std::string path; @@ -19,6 +18,7 @@ namespace Model double downloadProgress; int lastSelected; std::atomic_bool cancelDownload{ false }; + float size; // Default constructor is fine. ModelVariant() = default; @@ -31,7 +31,8 @@ namespace Model , isDownloaded(other.isDownloaded) , downloadProgress(other.downloadProgress) , lastSelected(other.lastSelected) - , cancelDownload(false) // Always initialize to false on copy. + , cancelDownload(false) + , size(other.size) { } @@ -44,7 +45,8 @@ namespace Model isDownloaded = other.isDownloaded; downloadProgress = other.downloadProgress; lastSelected = other.lastSelected; - cancelDownload = false; // Reinitialize the cancellation flag. 
+ cancelDownload = false; + size = other.size; } return *this; } @@ -58,7 +60,8 @@ namespace Model {"downloadLink", v.downloadLink}, {"isDownloaded", v.isDownloaded}, {"downloadProgress", v.downloadProgress}, - {"lastSelected", v.lastSelected} }; + {"lastSelected", v.lastSelected}, + {"size", v.size} }; } inline void from_json(const nlohmann::json& j, ModelVariant& v) @@ -69,6 +72,7 @@ namespace Model j.at("isDownloaded").get_to(v.isDownloaded); j.at("downloadProgress").get_to(v.downloadProgress); j.at("lastSelected").get_to(v.lastSelected); + j.at("size").get_to(v.size); } // Refactored ModelData to use a map of variants @@ -77,10 +81,17 @@ namespace Model std::string name; std::string author; std::map variants; + float_t hidden_size; + float_t attention_heads; + float_t hidden_layers; + float_t kv_heads; // Constructor with no variants - ModelData(const std::string& name = "", const std::string& author = "") - : name(name), author(author) { + ModelData(const std::string& name = "", const std::string& author = "", + const float_t hidden_size = 0, const float_t attention_heads = 0, + const float_t hidden_layers = 0, const float_t kv_heads = 0) + : name(name), author(author), hidden_size(hidden_size), attention_heads(attention_heads) + , hidden_layers(hidden_layers), kv_heads(kv_heads) { } // Add a variant to the model @@ -111,7 +122,11 @@ namespace Model j = nlohmann::json{ {"name", m.name}, {"author", m.author}, - {"variants", m.variants} + {"variants", m.variants}, + {"hidden_size", m.hidden_size}, + {"attention_heads", m.attention_heads}, + {"hidden_layers", m.hidden_layers}, + {"kv_heads", m.kv_heads} }; } @@ -120,5 +135,9 @@ namespace Model j.at("name").get_to(m.name); j.at("author").get_to(m.author); j.at("variants").get_to(m.variants); + j.at("hidden_size").get_to(m.hidden_size); + j.at("attention_heads").get_to(m.attention_heads); + j.at("hidden_layers").get_to(m.hidden_layers); + j.at("kv_heads").get_to(m.kv_heads); } } // namespace Model \ No newline at 
end of file diff --git a/include/model/model_manager.hpp b/include/model/model_manager.hpp index 7605658..485c50a 100644 --- a/include/model/model_manager.hpp +++ b/include/model/model_manager.hpp @@ -1,8 +1,10 @@ #pragma once +#include "system_monitor.hpp" #include "preset_manager.hpp" #include "model_persistence.hpp" #include "model_loader_config_manager.hpp" +#include "threadpool.hpp" #include #include @@ -33,6 +35,8 @@ namespace Model { static std::atomic seqCounter; + // TODO: Instead of using singleton, i'm thinking of approaching it using a C style implementation + // to avoid the overhead of singleton pattern, and to make it more readable and maintainable. class ModelManager { public: @@ -55,30 +59,31 @@ namespace Model m_currentModelIndex = 0; } - bool unloadModel() + bool unloadModel(const std::string modelName) { std::unique_lock lock(m_mutex); - if (!m_modelLoaded) - { - return false; - } + if (!m_unloadInProgress.empty()) + { + std::cerr << "[ModelManager] Unload already in progress\n"; + return false; + } - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { return false; } - m_unloadInProgress = true; + m_unloadInProgress = modelName; lock.unlock(); // Start async unloading process - auto unloadFuture = unloadModelAsync(); + auto unloadFuture = unloadModelAsync(modelName); // Handle unload completion m_unloadFutures.emplace_back(std::async(std::launch::async, - [this, unloadFuture = std::move(unloadFuture)]() mutable { + [this, unloadFuture = std::move(unloadFuture), modelName]() mutable { if (unloadFuture.get()) { std::cout << "[ModelManager] Successfully unloaded model\n"; @@ -90,9 +95,13 @@ namespace Model { std::unique_lock lock(m_mutex); - m_modelLoaded = false; - m_unloadInProgress = false; - resetModelState(); + m_unloadInProgress = ""; + + if (modelName == m_currentModelName) + { + m_modelLoaded = false; + resetModelState(); + } } })); } @@ -124,27 +133,23 @@ namespace Model return true; } - // Handle already downloaded case - 
variant->lastSelected = static_cast(std::time(nullptr)); - m_persistence->saveModelData(m_models[m_currentModelIndex]); - // Prevent concurrent model loading - if (m_loadInProgress) { + if (!m_loadInProgress.empty()) { std::cerr << "[ModelManager] Already loading a model, cannot switch now\n"; return false; } - m_loadInProgress = true; - + m_loadInProgress = modelName; + // Release lock before async operations lock.unlock(); // Start async loading process - auto loadFuture = loadModelIntoEngineAsync(); - + auto loadFuture = loadModelIntoEngineAsync(modelName); + // Handle load completion m_loadFutures.emplace_back(std::async(std::launch::async, - [this, loadFuture = std::move(loadFuture)]() mutable { + [this, modelName, loadFuture = std::move(loadFuture), variant]() mutable { bool success = false; try { success = loadFuture.get(); @@ -155,13 +160,17 @@ namespace Model { std::unique_lock lock(m_mutex); - m_loadInProgress = false; + m_loadInProgress = ""; if (success) { m_modelLoaded = true; std::cout << "[ModelManager] Successfully switched models\n"; + variant->lastSelected = static_cast(std::time(nullptr)); + m_persistence->saveModelData(m_models[m_currentModelIndex]); } else { + // Clean up the failed engine + cleanupFailedEngine(modelName); resetModelState(); std::cerr << "[ModelManager] Failed to load model\n"; } @@ -180,6 +189,54 @@ namespace Model return true; } + bool loadModelIntoEngine(const std::string& modelName) + { + std::unique_lock lock(m_mutex); + // Prevent concurrent model loading + if (!m_loadInProgress.empty()) { + std::cerr << "[ModelManager] Already loading a model, cannot load now\n"; + return false; + } + m_loadInProgress = modelName; + // Release lock before async operations + lock.unlock(); + // Start async loading process + auto loadFuture = loadModelIntoEngineAsync(modelName); + // Handle load completion + m_loadFutures.emplace_back(std::async(std::launch::async, + [this, modelName, loadFuture = std::move(loadFuture)]() mutable { + bool 
success = false; + try { + success = loadFuture.get(); + } + catch (const std::exception& e) { + std::cerr << "[ModelManager] Model load error: " << e.what() << "\n"; + } + { + std::unique_lock lock(m_mutex); + m_loadInProgress = ""; + if (success) { + m_modelLoaded = true; + std::cout << "[ModelManager] Successfully loaded model\n"; + } + else { + // Clean up the failed engine + cleanupFailedEngine(modelName); + std::cerr << "[ModelManager] Failed to load model\n"; + } + } + // Cleanup completed futures + m_loadFutures.erase( + std::remove_if(m_loadFutures.begin(), m_loadFutures.end(), + [](const std::future& f) { + return f.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + }), + m_loadFutures.end() + ); + })); + return true; + } + bool downloadModel(size_t modelIndex, const std::string &variantType) { std::unique_lock lock(m_mutex); @@ -296,6 +353,52 @@ namespace Model return params; } + CompletionParameters buildCompletionParameters(const CompletionRequest& request) { + CompletionParameters params; + + // Set prompt based on request format + if (std::holds_alternative(request.prompt)) { + params.prompt = std::get(request.prompt); + } + else if (std::holds_alternative>(request.prompt)) { + // Join multiple prompts with newlines if array is provided + const auto& prompts = std::get>(request.prompt); + std::ostringstream joined; + for (size_t i = 0; i < prompts.size(); ++i) { + joined << prompts[i]; + if (i < prompts.size() - 1) { + joined << "\n"; + } + } + params.prompt = joined.str(); + } + + // Map parameters from request to our format + if (request.seed.has_value()) { + params.randomSeed = request.seed.value(); + } + + if (request.max_tokens.has_value()) { + params.maxNewTokens = request.max_tokens.value(); + } + else { + // Use a reasonable default if not specified (OpenAI default is 16) + params.maxNewTokens = 16; + } + + // Copy other parameters + params.temperature = request.temperature; + params.topP = request.top_p; + params.streaming = 
request.stream; + + // Set unique sequence ID based on timestamp + auto now = std::chrono::system_clock::now(); + auto timestamp = std::chrono::duration_cast(now.time_since_epoch()).count(); + params.seqId = static_cast(timestamp * 1000 + seqCounter++); + + return params; + } + ChatCompletionParameters buildChatCompletionParameters( const Chat::ChatHistory& currentChat, const std::string& userInput @@ -386,133 +489,136 @@ namespace Model return completionParams; } - bool stopJob(int jobId) + bool stopJob(int jobId, const std::string modelName) { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return false; } - m_inferenceEngine->stopJob(jobId); + + // Mark the job as inactive in our tracking map + { + auto it = m_activeJobs.find(jobId); + if (it != m_activeJobs.end()) { + it->second = false; + } + } + + m_inferenceEngines.at(modelName)->stopJob(jobId); return true; } - CompletionResult completeSync(const CompletionParameters& params) + CompletionResult completeSync(const CompletionParameters& params, const std::string modelName) { + CompletionResult emptyResult; + emptyResult.text = ""; + emptyResult.tps = 0.0F; + { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; - CompletionResult result; - result.text = ""; - result.tps = 0.0F; - return result; + return emptyResult; } if (!m_modelLoaded) { std::cerr << "[ModelManager] No model is currently loaded.\n"; - CompletionResult result; - result.text = ""; - result.tps = 0.0F; - return result; + return emptyResult; } } - int jobId = m_inferenceEngine->submitCompletionsJob(params); + int jobId = m_inferenceEngines.at(modelName)->submitCompletionsJob(params); if (jobId < 0) { std::cerr << "[ModelManager] Failed to submit completions job.\n"; - CompletionResult result; - result.text 
= ""; - result.tps = 0.0F; - return result; + return emptyResult; } // Add job ID with proper synchronization { std::unique_lock lock(m_mutex); m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; } // Wait for the job to complete - m_inferenceEngine->waitForJob(jobId); + m_inferenceEngines.at(modelName)->waitForJob(jobId); // Get the final result - CompletionResult result = m_inferenceEngine->getJobResult(jobId); + CompletionResult result = m_inferenceEngines.at(modelName)->getJobResult(jobId); // Check for errors - if (m_inferenceEngine->hasJobError(jobId)) { + if (m_inferenceEngines.at(modelName)->hasJobError(jobId)) { std::cerr << "[ModelManager] Error in completion job: " - << m_inferenceEngine->getJobError(jobId) << std::endl; + << m_inferenceEngines.at(modelName)->getJobError(jobId) << std::endl; } // Clean up with proper synchronization { std::unique_lock lock(m_mutex); m_jobIds.erase(std::remove(m_jobIds.begin(), m_jobIds.end(), jobId), m_jobIds.end()); + m_activeJobs.erase(jobId); } return result; } - CompletionResult chatCompleteSync(const ChatCompletionParameters& params, const bool saveChat = true) + CompletionResult chatCompleteSync(const ChatCompletionParameters& params, const std::string modelName, const bool saveChat = true) { + CompletionResult emptyResult; + emptyResult.text = ""; + emptyResult.tps = 0.0F; + { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; - CompletionResult result; - result.text = ""; - result.tps = 0.0F; - return result; + return emptyResult; } if (!m_modelLoaded) { std::cerr << "[ModelManager] No model is currently loaded.\n"; - CompletionResult result; - result.text = ""; - result.tps = 0.0F; - return result; + return emptyResult; } } - int jobId = m_inferenceEngine->submitChatCompletionsJob(params); + int jobId = m_inferenceEngines.at(modelName)->submitChatCompletionsJob(params); if (jobId < 0) { 
std::cerr << "[ModelManager] Failed to submit chat completions job.\n"; - CompletionResult result; - result.text = ""; - result.tps = 0.0F; - return result; + return emptyResult; } // Add job ID with proper synchronization { std::unique_lock lock(m_mutex); m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; } // Wait for the job to complete - m_inferenceEngine->waitForJob(jobId); + m_inferenceEngines.at(modelName)->waitForJob(jobId); // Get the final result - CompletionResult result = m_inferenceEngine->getJobResult(jobId); + CompletionResult result = m_inferenceEngines.at(modelName)->getJobResult(jobId); // Check for errors - if (m_inferenceEngine->hasJobError(jobId)) { + if (m_inferenceEngines.at(modelName)->hasJobError(jobId)) { std::cerr << "[ModelManager] Error in chat completion job: " - << m_inferenceEngine->getJobError(jobId) << std::endl; + << m_inferenceEngines.at(modelName)->getJobError(jobId) << std::endl; } // Clean up with proper synchronization { std::unique_lock lock(m_mutex); m_jobIds.erase(std::remove(m_jobIds.begin(), m_jobIds.end(), jobId), m_jobIds.end()); + m_activeJobs.erase(jobId); } // Save the chat history - if (saveChat) + if (saveChat) { auto& chatManager = Chat::ChatManager::getInstance(); auto chatName = chatManager.getChatNameByJobId(jobId); @@ -531,11 +637,12 @@ namespace Model return result; } - int startCompletionJob(const CompletionParameters& params, std::function streamingCallback, const bool saveChat = true) + int startCompletionJob(const CompletionParameters& params, std::function streamingCallback, const std::string modelName, const bool saveChat = true) { { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return -1; @@ -547,7 +654,7 @@ namespace Model } } - int jobId = m_inferenceEngine->submitCompletionsJob(params); + int jobId = m_inferenceEngines.at(modelName)->submitCompletionsJob(params); if 
(jobId < 0) { std::cerr << "[ModelManager] Failed to submit completions job.\n"; return -1; @@ -557,16 +664,25 @@ namespace Model { std::unique_lock lock(m_mutex); m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; } - std::thread([this, jobId, streamingCallback, saveChat]() { + // Use thread pool instead of creating a detached thread + m_threadPool.enqueue([this, jobId, streamingCallback, saveChat, modelName]() { // Poll while job is running or until the engine says it's done while (true) { - if (this->m_inferenceEngine->hasJobError(jobId)) break; + // Check if job was stopped externally + { + std::shared_lock lock(m_mutex); + auto it = m_activeJobs.find(jobId); + if (it == m_activeJobs.end() || !it->second) break; + } + + if (this->m_inferenceEngines.at(modelName)->hasJobError(jobId)) break; - CompletionResult partial = this->m_inferenceEngine->getJobResult(jobId); - bool isFinished = this->m_inferenceEngine->isJobFinished(jobId); + CompletionResult partial = this->m_inferenceEngines.at(modelName)->getJobResult(jobId); + bool isFinished = this->m_inferenceEngines.at(modelName)->isJobFinished(jobId); if (!partial.text.empty()) { // Call the user's callback (no need to lock for the callback) @@ -585,28 +701,28 @@ namespace Model { std::unique_lock lock(m_mutex); m_jobIds.erase(std::remove(m_jobIds.begin(), m_jobIds.end(), jobId), m_jobIds.end()); + m_activeJobs.erase(jobId); } // Reset jobid tracking on chat manager + if (saveChat) { - if (saveChat) + if (!Chat::ChatManager::getInstance().removeJobId(jobId)) { - if (!Chat::ChatManager::getInstance().removeJobId(jobId)) - { - std::cerr << "[ModelManager] Failed to remove job id from chat manager.\n"; - } + std::cerr << "[ModelManager] Failed to remove job id from chat manager.\n"; } } - }).detach(); + }); return jobId; } - int startChatCompletionJob(const ChatCompletionParameters& params, std::function streamingCallback, const bool saveChat = true) + int startChatCompletionJob(const ChatCompletionParameters& 
params, std::function streamingCallback, const std::string modelName, const bool saveChat = true) { { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return -1; @@ -618,7 +734,7 @@ namespace Model } } - int jobId = m_inferenceEngine->submitChatCompletionsJob(params); + int jobId = m_inferenceEngines.at(modelName)->submitChatCompletionsJob(params); if (jobId < 0) { std::cerr << "[ModelManager] Failed to submit chat completions job.\n"; return -1; @@ -628,15 +744,24 @@ namespace Model { std::unique_lock lock(m_mutex); m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; } - std::thread([this, jobId, streamingCallback, saveChat]() { + // Use thread pool instead of creating a detached thread + m_threadPool.enqueue([this, jobId, streamingCallback, saveChat, modelName]() { while (true) { - if (this->m_inferenceEngine->hasJobError(jobId)) break; + // Check if job was stopped externally + { + std::shared_lock lock(m_mutex); + auto it = m_activeJobs.find(jobId); + if (it == m_activeJobs.end() || !it->second) break; + } - CompletionResult partial = this->m_inferenceEngine->getJobResult(jobId); - bool isFinished = this->m_inferenceEngine->isJobFinished(jobId); + if (this->m_inferenceEngines.at(modelName)->hasJobError(jobId)) break; + + CompletionResult partial = this->m_inferenceEngines.at(modelName)->getJobResult(jobId); + bool isFinished = this->m_inferenceEngines.at(modelName)->isJobFinished(jobId); if (!partial.text.empty()) { // Call the user's callback (no need to lock for the callback) @@ -655,6 +780,7 @@ namespace Model { std::unique_lock lock(m_mutex); m_jobIds.erase(std::remove(m_jobIds.begin(), m_jobIds.end(), jobId), m_jobIds.end()); + m_activeJobs.erase(jobId); } if (saveChat) @@ -678,53 +804,53 @@ namespace Model } } } - }).detach(); + }); return jobId; } - bool isJobFinished(int jobId) + bool isJobFinished(int jobId, const std::string 
modelName) const { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return true; // No engine means nothing is running } - return m_inferenceEngine->isJobFinished(jobId); + return m_inferenceEngines.at(modelName)->isJobFinished(jobId); } - CompletionResult getJobResult(int jobId) + CompletionResult getJobResult(int jobId, const std::string modelName) const { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return { {}, "" }; } - return m_inferenceEngine->getJobResult(jobId); + return m_inferenceEngines.at(modelName)->getJobResult(jobId); } - bool hasJobError(int jobId) + bool hasJobError(int jobId, const std::string modelName) const { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return true; } - return m_inferenceEngine->hasJobError(jobId); + return m_inferenceEngines.at(modelName)->hasJobError(jobId); } - std::string getJobError(int jobId) + std::string getJobError(int jobId, const std::string modelName) const { std::shared_lock lock(m_mutex); - if (!m_inferenceEngine) + if (!m_inferenceEngines.at(modelName)) { std::cerr << "[ModelManager] Inference engine is not initialized.\n"; return "Inference engine not initialized"; } - return m_inferenceEngine->getJobError(jobId); + return m_inferenceEngines.at(modelName)->getJobError(jobId); } //-------------------------------------------------------------------------------------------- @@ -740,19 +866,35 @@ namespace Model Logger::instance().setLevel(LogLevel::SERVER_INFO); Logger::logInfo("Starting model server on port %s", port.c_str()); - // Set inference callbacks - kolosal::ServerAPI::instance().setInferenceCallback( + // Set chat completion callbacks + 
kolosal::ServerAPI::instance().setChatCompletionCallback( [this](const ChatCompletionRequest& request) { - return this->handleNonStreamingRequest(request); + return this->handleChatCompletionRequest(request); } ); - kolosal::ServerAPI::instance().setStreamingInferenceCallback( + kolosal::ServerAPI::instance().setChatCompletionStreamingCallback( [this](const ChatCompletionRequest& request, const std::string& requestId, int chunkIndex, ChatCompletionChunk& outputChunk) { - return this->handleStreamingRequest(request, requestId, chunkIndex, outputChunk); + return this->handleChatCompletionStreamingRequest(request, requestId, chunkIndex, outputChunk); + } + ); + + // Set completion callbacks + kolosal::ServerAPI::instance().setCompletionCallback( + [this](const CompletionRequest& request) { + return this->handleCompletionRequest(request); + } + ); + + kolosal::ServerAPI::instance().setCompletionStreamingCallback( + [this](const CompletionRequest& request, + const std::string& requestId, + int chunkIndex, + CompletionChunk& outputChunk) { + return this->handleCompletionStreamingRequest(request, requestId, chunkIndex, outputChunk); } ); @@ -771,7 +913,7 @@ namespace Model kolosal::ServerAPI::instance().shutdown(); } - ChatCompletionResponse handleNonStreamingRequest(const ChatCompletionRequest& request) { + ChatCompletionResponse handleChatCompletionRequest(const ChatCompletionRequest& request) { // Build parameters from the incoming request. ChatCompletionParameters params = buildChatCompletionParameters(request); // (The parameters will include the messages and other fields.) 
@@ -785,20 +927,40 @@ namespace Model return response; } - bool ModelManager::handleStreamingRequest( + CompletionResponse handleCompletionRequest(const CompletionRequest& request) { + // Build parameters from the incoming request + CompletionParameters params = buildCompletionParameters(request); + params.streaming = false; + + // Invoke the synchronous completion method + CompletionResult result = completeSync(params, request.model); + + // Map the engine's result to our CompletionResponse + CompletionResponse response = convertToCompletionResponse(request, result); + return response; + } + + bool handleChatCompletionStreamingRequest( const ChatCompletionRequest& request, const std::string& requestId, int chunkIndex, ChatCompletionChunk& outputChunk) { - // Look up (or create) the StreamingContext for this requestId. - std::shared_ptr ctx; + + // Check if the model name is loaded + if (!m_inferenceEngines.at(request.model)) { + Logger::logError("[ModelManager] Inference engine is not initialized."); + return false; + } + + // Look up (or create) the ChatCompletionStreamingContext for this requestId. + std::shared_ptr ctx; { std::unique_lock lock(m_streamContextsMutex); auto it = m_streamingContexts.find(requestId); if (it == m_streamingContexts.end()) { // For the very first chunk (chunkIndex==0) we create a new context. if (chunkIndex == 0) { - ctx = std::make_shared(); + ctx = std::make_shared(); m_streamingContexts[requestId] = ctx; } else { @@ -825,7 +987,7 @@ namespace Model { std::lock_guard lock(ctx->mtx); ctx->model = request.model; - ctx->jobId = m_inferenceEngine->submitChatCompletionsJob(params); + ctx->jobId = m_inferenceEngines.at(request.model)->submitChatCompletionsJob(params); jobId = ctx->jobId; } @@ -849,18 +1011,30 @@ namespace Model { std::unique_lock lock(m_mutex); m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; } - // Launch an asynchronous thread that polls the job and accumulates new text. 
- std::thread([this, jobId, requestId, ctx]() { + // Use thread pool instead of detached thread + m_threadPool.enqueue([this, jobId, request, requestId, ctx]() { std::string lastText; auto startTime = std::chrono::steady_clock::now(); try { while (true) { + // Check if job was stopped externally + { + std::shared_lock lock(m_mutex); + auto it = m_activeJobs.find(jobId); + if (it == m_activeJobs.end() || !it->second) { + std::lock_guard ctxLock(ctx->mtx); + ctx->finished = true; + break; + } + } + // Check if the job has an error - if (this->m_inferenceEngine->hasJobError(jobId)) { - std::string errorMsg = this->m_inferenceEngine->getJobError(jobId); + if (this->m_inferenceEngines.at(request.model)->hasJobError(jobId)) { + std::string errorMsg = this->m_inferenceEngines.at(request.model)->getJobError(jobId); Logger::logError("[ModelManager] Streaming job error for jobId: %d - %s", jobId, errorMsg.c_str()); { @@ -874,8 +1048,8 @@ namespace Model } // Get the current result and check if finished - CompletionResult partial = this->m_inferenceEngine->getJobResult(jobId); - bool isFinished = this->m_inferenceEngine->isJobFinished(jobId); + CompletionResult partial = this->m_inferenceEngines.at(request.model)->getJobResult(jobId); + bool isFinished = this->m_inferenceEngines.at(request.model)->isJobFinished(jobId); // Compute delta text (only new text since last poll). 
std::string newText; @@ -909,6 +1083,9 @@ namespace Model ctx->cv.notify_all(); break; } + + // Sleep briefly to avoid busy-waiting + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } catch (const std::exception& e) { @@ -928,10 +1105,9 @@ namespace Model this->m_jobIds.erase( std::remove(this->m_jobIds.begin(), this->m_jobIds.end(), jobId), this->m_jobIds.end()); + m_activeJobs.erase(jobId); } - - // We don't erase the streaming context here - that happens when the last chunk is requested - }).detach(); + }); } if (chunkIndex == 0) { @@ -1039,6 +1215,261 @@ namespace Model } } + bool handleCompletionStreamingRequest( + const CompletionRequest& request, + const std::string& requestId, + int chunkIndex, + CompletionChunk& outputChunk) { + + // Get or create streaming context + std::shared_ptr ctx; + { + std::unique_lock lock(m_completionStreamContextsMutex); + auto it = m_completionStreamingContexts.find(requestId); + if (it == m_completionStreamingContexts.end()) { + // For first chunk, create a new context + if (chunkIndex == 0) { + ctx = std::make_shared(); + m_completionStreamingContexts[requestId] = ctx; + } + else { + Logger::logError("[ModelManager] Completion streaming context not found for requestId: %s", + requestId.c_str()); + return false; + } + } + else { + ctx = it->second; + } + } + + // If this is the first call, start the asynchronous job + if (chunkIndex == 0) { + // Build parameters with streaming enabled + CompletionParameters params = buildCompletionParameters(request); + params.streaming = true; + + // Track job ID and model for this request + int jobId = -1; + + { + std::lock_guard lock(ctx->mtx); + ctx->model = request.model; + + // Submit the completion job to the inference engine + jobId = m_inferenceEngines.at(request.model)->submitCompletionsJob(params); + ctx->jobId = jobId; + } + + if (jobId < 0) { + Logger::logError("[ModelManager] Failed to submit completion job for requestId: %s", + requestId.c_str()); + { + 
std::lock_guard lock(ctx->mtx); + ctx->error = true; + ctx->errorMessage = "Failed to start completion job"; + ctx->finished = true; + } + { + std::unique_lock lock(m_completionStreamContextsMutex); + m_completionStreamingContexts.erase(requestId); + } + return false; + } + + // Add job ID to global tracking + { + std::unique_lock lock(m_mutex); + m_jobIds.push_back(jobId); + m_activeJobs[jobId] = true; + } + + // Use thread pool instead of detached thread + m_threadPool.enqueue([this, jobId, request, requestId, ctx]() { + std::string lastText; + auto startTime = std::chrono::steady_clock::now(); + + try { + while (true) { + // Check if job was stopped externally + { + std::shared_lock lock(m_mutex); + auto it = m_activeJobs.find(jobId); + if (it == m_activeJobs.end() || !it->second) { + std::lock_guard ctxLock(ctx->mtx); + ctx->finished = true; + break; + } + } + + // Check if the job has an error + if (this->m_inferenceEngines.at(request.model)->hasJobError(jobId)) { + std::string errorMsg = this->m_inferenceEngines.at(request.model)->getJobError(jobId); + Logger::logError("[ModelManager] Streaming completion job error for jobId: %d - %s", + jobId, errorMsg.c_str()); + { + std::lock_guard lock(ctx->mtx); + ctx->error = true; + ctx->errorMessage = errorMsg; + ctx->finished = true; + } + ctx->cv.notify_all(); + break; + } + + // Get the current result and check if finished + CompletionResult partial = this->m_inferenceEngines.at(request.model)->getJobResult(jobId); + bool isFinished = this->m_inferenceEngines.at(request.model)->isJobFinished(jobId); + + // Compute delta text (only new text since last poll) + std::string newText; + if (partial.text.size() > lastText.size()) { + newText = partial.text.substr(lastText.size()); + lastText = partial.text; + } + + // If we have new text, add it to the chunks + if (!newText.empty()) { + { + std::lock_guard lock(ctx->mtx); + ctx->fullText = lastText; + ctx->chunks.push_back(newText); + } + ctx->cv.notify_all(); + } + + // 
If the job is finished, set the finished flag and break + if (isFinished) { + auto endTime = std::chrono::steady_clock::now(); + auto durationMs = std::chrono::duration_cast( + endTime - startTime).count(); + + Logger::logInfo("[ModelManager] Streaming completion job %d completed in %lld ms", + jobId, durationMs); + + { + std::lock_guard lock(ctx->mtx); + ctx->finished = true; + } + ctx->cv.notify_all(); + break; + } + + // Sleep briefly to avoid busy-waiting + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } + catch (const std::exception& e) { + Logger::logError("[ModelManager] Exception in completion streaming thread: %s", e.what()); + { + std::lock_guard lock(ctx->mtx); + ctx->error = true; + ctx->errorMessage = e.what(); + ctx->finished = true; + } + ctx->cv.notify_all(); + } + + // Clean up job ID tracking + { + std::unique_lock lock(this->m_mutex); + this->m_jobIds.erase( + std::remove(this->m_jobIds.begin(), this->m_jobIds.end(), jobId), + this->m_jobIds.end()); + m_activeJobs.erase(jobId); + } + }); + } + + // Prepare the chunk response + outputChunk.id = requestId; + outputChunk.model = request.model; + outputChunk.created = static_cast(std::time(nullptr)); + outputChunk.choices.clear(); + + // For first chunk, just create an empty choice + if (chunkIndex == 0) { + CompletionChunkChoice choice; + choice.index = 0; + choice.text = ""; + choice.finish_reason = ""; // Use empty string instead of nullptr + outputChunk.choices.push_back(choice); + return true; + } + // For subsequent chunks, wait for content + else { + std::unique_lock lock(ctx->mtx); + + // Wait with timeout for the chunk to be available + bool result = ctx->cv.wait_for(lock, std::chrono::seconds(30), [ctx, chunkIndex]() { + return (ctx->chunks.size() >= static_cast(chunkIndex)) || + ctx->finished || ctx->error; + }); + + if (!result) { + // Timeout occurred + Logger::logError("[ModelManager] Timeout waiting for completion chunk %d for requestId %s", + chunkIndex, 
requestId.c_str()); + + // Keep the lock when we check if this is the last message + std::unique_lock glock(m_completionStreamContextsMutex); + m_completionStreamingContexts.erase(requestId); + return false; + } + + // Handle errors - still holding the lock + if (ctx->error) { + Logger::logError("[ModelManager] Error in streaming completion for requestId %s: %s", + requestId.c_str(), ctx->errorMessage.c_str()); + + // Keep the lock when we check if this is the last message + std::unique_lock glock(m_completionStreamContextsMutex); + m_completionStreamingContexts.erase(requestId); + return false; + } + + CompletionChunkChoice choice; + choice.index = 0; + + // Check for completion state while still holding the lock + bool hasChunk = ctx->chunks.size() >= static_cast(chunkIndex); + bool isFinished = ctx->finished; + bool isLastChunk = false; + + if (hasChunk) { + // Get the content for this chunk while holding the lock + choice.text = ctx->chunks[chunkIndex - 1]; + + // Determine if this is the last chunk while safely protected by the lock + isLastChunk = isFinished && (ctx->chunks.size() == static_cast(chunkIndex)); + choice.finish_reason = isLastChunk ? 
"stop" : ""; // Use empty string instead of nullptr + } + else if (isFinished) { + // No chunk but job is finished - send empty final chunk + choice.text = ""; + choice.finish_reason = "stop"; + isLastChunk = true; + } + else { + // We have no chunk yet but still waiting + choice.text = ""; + choice.finish_reason = ""; // Use empty string instead of nullptr + } + + // Release the ctx lock before acquiring the global contexts lock to avoid deadlock + lock.unlock(); + + // Clean up if this is the last chunk + if (isLastChunk) { + std::unique_lock glock(m_completionStreamContextsMutex); + m_completionStreamingContexts.erase(requestId); + } + + outputChunk.choices.push_back(choice); + return !isLastChunk; // Return true if more chunks remain, false if this is the last one + } + } + std::string getCurrentVariantForModel(const std::string& modelName) const { auto it = m_modelVariantMap.find(modelName); @@ -1076,12 +1507,11 @@ namespace Model if (modelIndex == m_currentModelIndex && variantType == m_currentVariantType) { - unloadModel(); + unloadModel(m_models[modelIndex].name); } // Call the persistence layer to delete the file - passing the variant type instead of the variant - auto fut = m_persistence->deleteModelVariant(m_models[modelIndex], variantType); - fut.get(); // Wait for deletion to complete. 
+ m_persistence->deleteModelVariant(m_models[modelIndex], variantType); return true; } @@ -1092,6 +1522,28 @@ namespace Model m_modelLoaded = false; } + void cleanupFailedEngine(const std::string& modelName) { + auto it = m_inferenceEngines.find(modelName); + if (it != m_inferenceEngines.end()) { + // Release resources if the engine implementation requires it + if (it->second) { + it->second->unloadModel(); + } + m_inferenceEngines.erase(it); + } + } + + bool retryModelLoad(const std::string& modelName, const std::string& variantType) { + // First clean up any previous failed attempt + { + std::unique_lock lock(m_mutex); + cleanupFailedEngine(modelName); + } + + // Then try to switch to this model again + return switchModel(modelName, variantType); + } + bool isCurrentlyGenerating() const { std::shared_lock lock(m_mutex); @@ -1111,17 +1563,145 @@ namespace Model } bool isLoadInProgress() const + { + std::shared_lock lock(m_mutex); + return !m_loadInProgress.empty(); + } + + std::string getCurrentOnLoadingModel() const { std::shared_lock lock(m_mutex); return m_loadInProgress; } bool isUnloadInProgress() const + { + std::shared_lock lock(m_mutex); + return !m_unloadInProgress.empty(); + } + + std::string getCurrentOnUnloadingModel() const { std::shared_lock lock(m_mutex); return m_unloadInProgress; } + bool isModelLoaded(const std::string& modelName) const + { + std::shared_lock lock(m_mutex); + auto it = m_inferenceEngines.find(modelName); + if (it != m_inferenceEngines.end()) + { + return it->second != nullptr; + } + return false; + } + + bool hasEnoughMemoryForModel(const std::string& modelName, float& memoryReqBuff, float& kvReqBuff) { + auto it = m_modelNameToIndex.find(modelName); + if (it == m_modelNameToIndex.end()) { + std::cerr << "[ModelManager] Model not found: " << modelName << "\n"; + return false; + } + + size_t modelIndex = it->second; + const auto& model = m_models[modelIndex]; + const auto& variant = model.variants.at( + 
getCurrentVariantForModel(modelName) + ); + + // Calculate model size in bytes (convert from GB) + size_t modelSizeBytes = static_cast(variant.size * 1024 * 1024 * 1024); + + // Calculate KV cache size based on model parameters + // KV cache formula: 2 (key & value) * hidden_size * hidden_layers * max_seq_length * bytes_per_token + const size_t MAX_SEQUENCE_LENGTH = ModelLoaderConfigManager::getInstance().getConfig().n_ctx; + + float_t kvCacheSizeBytes = 4 * + model.hidden_size * + model.hidden_layers * + MAX_SEQUENCE_LENGTH; + + // Update the buffers in MB + memoryReqBuff = (modelSizeBytes) / (1024 * 1024); + kvReqBuff = (kvCacheSizeBytes) / (1024 * 1024); + + // Check if we have enough memory using SystemMonitor + auto& sysMonitor = SystemMonitor::getInstance(); + bool hasEnoughMemory = sysMonitor.hasEnoughMemoryForModel( + modelSizeBytes, + kvCacheSizeBytes + ); + + return hasEnoughMemory; + } + + bool hasEnoughMemoryForModel(const std::string& modelName) { + auto it = m_modelNameToIndex.find(modelName); + if (it == m_modelNameToIndex.end()) { + std::cerr << "[ModelManager] Model not found: " << modelName << "\n"; + return false; + } + + size_t modelIndex = it->second; + const auto& model = m_models[modelIndex]; + const auto& variant = model.variants.at( + getCurrentVariantForModel(modelName) + ); + + // Calculate model size in bytes (convert from GB) + size_t modelSizeBytes = static_cast(variant.size * 1024 * 1024 * 1024); + + // Calculate KV cache size based on model parameters + // KV cache formula: 2 (key & value) * hidden_size * hidden_layers * max_seq_length * bytes_per_token + const size_t MAX_SEQUENCE_LENGTH = ModelLoaderConfigManager::getInstance().getConfig().n_ctx; + const size_t BYTES_PER_TOKEN = 2; // Assuming FP16 precision kv (2 bytes) + + size_t kvCacheSizeBytes = 4 * + model.hidden_size * + model.hidden_layers * + MAX_SEQUENCE_LENGTH; + + // Check if we have enough memory using SystemMonitor + auto& sysMonitor = SystemMonitor::getInstance(); + 
bool hasEnoughMemory = sysMonitor.hasEnoughMemoryForModel( + modelSizeBytes, + kvCacheSizeBytes + ); + + return hasEnoughMemory; + } + + bool addCustomModel(const Model::ModelData modelData) + { + std::unique_lock lock(m_mutex); + + if (m_modelNameToIndex.count(modelData.name)) { + std::cerr << "[ModelManager] Model with name '" << modelData.name << "' already exists.\n"; + return false; + } + + if (modelData.variants.empty()) { + std::cerr << "[ModelManager] Cannot add model with no variants\n"; + return false; + } + + m_models.push_back(modelData); + m_modelNameToIndex[modelData.name] = m_models.size() - 1; + + // save the model to persistence + m_persistence->saveModelData(modelData); + + // Update the model variant map + m_modelVariantMap[modelData.name] = modelData.variants.begin()->first; + return true; + } + + const bool isUsingGpu() const { + std::shared_lock lock(m_mutex); + return m_isVulkanBackend; + } + private: explicit ModelManager(std::unique_ptr persistence, const bool async = true) : m_persistence(std::move(persistence)) @@ -1129,7 +1709,6 @@ namespace Model , m_currentModelIndex(0) , m_inferenceLibHandle(nullptr) , m_createInferenceEnginePtr(nullptr) - , m_inferenceEngine(nullptr) , m_modelLoaded(false) , m_modelGenerationInProgress(false) { @@ -1141,9 +1720,9 @@ namespace Model // Load inference engine backend and models synchronously loadModels(); - bool useVulkan = useVulkanBackend(); + m_isVulkanBackend = useVulkanBackend(); std::string backendName = "InferenceEngineLib.dll"; - if (useVulkan) + if (m_isVulkanBackend) { backendName = "InferenceEngineLibVulkan.dll"; } @@ -1165,9 +1744,19 @@ namespace Model m_initializationFuture.wait(); } - if (m_inferenceEngine && m_destroyInferenceEnginePtr) { - m_destroyInferenceEnginePtr(m_inferenceEngine); - m_inferenceEngine = nullptr; + // Clean up all inference engines + if (!m_inferenceEngines.empty()) + { + for (auto& [modelName, engine] : m_inferenceEngines) + { + if (engine && 
m_destroyInferenceEnginePtr) + { + m_destroyInferenceEnginePtr(engine); + } + + engine = nullptr; + m_inferenceEngines.erase(modelName); + } } if (m_inferenceLibHandle) { @@ -1193,48 +1782,44 @@ namespace Model } void startAsyncInitialization() { - m_initializationFuture = std::async(std::launch::async, [this]() { -#ifdef DEBUG - std::cout << "[ModelManager] Initializing model manager" << std::endl; -#endif - - // Run model loading and Vulkan check in parallel - auto modelsFuture = std::async(std::launch::async, &ModelManager::loadModels, this); - auto vulkanFuture = std::async(std::launch::async, &ModelManager::useVulkanBackend, this); - - modelsFuture.get(); - bool useVulkan = vulkanFuture.get(); + m_initializationFuture = m_threadPool.enqueue([this]() { + auto& sysMonitor = SystemMonitor::getInstance(); + sysMonitor.update(); + loadModels(); // blocking + m_isVulkanBackend = useVulkanBackend(); std::string backendName = "InferenceEngineLib.dll"; - if (useVulkan) { - backendName = "InferenceEngineLibVulkan.dll"; - } -#ifdef DEBUG - std::cout << "[ModelManager] Using backend: " << backendName << std::endl; -#endif + if (m_isVulkanBackend) + { + backendName = "InferenceEngineLibVulkan.dll"; + SystemMonitor::getInstance().initializeGpuMonitoring(); + } - if (!loadInferenceEngineDynamically(backendName.c_str())) { - std::cerr << "[ModelManager] Failed to load inference engine for backend: " - << backendName << std::endl; + if (!loadInferenceEngineDynamically(backendName)) { + std::cerr << "Failed to load inference engine\n"; return; } + std::optional name; { - std::unique_lock lock(m_mutex); - m_loadInProgress = true; + std::unique_lock lock(m_mutex); + if (m_currentModelName.has_value()) { + m_loadInProgress = m_currentModelName.value(); + name = m_currentModelName; + } } - // Start async model loading - auto modelLoadFuture = loadModelIntoEngineAsync(); - if (!modelLoadFuture.get()) { // Wait for async load to complete - resetModelState(); + if (name.has_value()) { 
+ auto future = loadModelIntoEngineAsync(name.value()); + if (!future.get()) { + std::cerr << "Failed to load model into engine\n"; + resetModelState(); + } } - { - std::unique_lock lock(m_mutex); - m_loadInProgress = false; - } + std::unique_lock lock(m_mutex); + m_loadInProgress.clear(); }); } @@ -1420,6 +2005,8 @@ namespace Model if (!variant) return; + const std::string modelName = m_models[modelIndex].name; + variant->downloadProgress = 0.01f; // 0% looks like no progress // Begin the asynchronous download - passing the variant type rather than the variant itself @@ -1427,7 +2014,7 @@ namespace Model // Chain a continuation that waits for the download to complete. m_downloadFutures.emplace_back(std::async(std::launch::async, - [this, modelIndex, variantType, fut = std::move(downloadFuture)]() mutable { + [this, modelIndex, modelName, variantType, fut = std::move(downloadFuture)]() mutable { // Wait for the download to finish. fut.wait(); @@ -1439,7 +2026,7 @@ namespace Model // Unlock before loading the model. 
lock.unlock(); - auto loadFuture = loadModelIntoEngineAsync(); + auto loadFuture = loadModelIntoEngineAsync(modelName); if (!loadFuture.get()) { std::unique_lock restoreLock(m_mutex); @@ -1697,70 +2284,60 @@ namespace Model << backendName << std::endl; #endif - m_inferenceEngine = m_createInferenceEnginePtr(); - if (!m_inferenceEngine) { - std::cerr << "[ModelManager] Failed to get InferenceEngine instance from " - << backendName << std::endl; - return false; - } #endif return true; } - std::future ModelManager::loadModelIntoEngineAsync() { - // Capture needed data under lock + std::future loadModelIntoEngineAsync(const std::string& modelName) { + if (!hasEnoughMemoryForModel(modelName)) { + std::promise promise; + promise.set_value(false); + return promise.get_future(); + } + std::optional modelDir; { - std::unique_lock lock(m_mutex); - - if (!m_currentModelName) { - std::cerr << "[ModelManager] No model selected\n"; - m_modelLoaded = false; - return std::async(std::launch::deferred, [] { return false; }); - } - - ModelVariant* variant = getVariantLocked(m_currentModelIndex, m_currentVariantType); - if (!variant || !variant->isDownloaded || !std::filesystem::exists(variant->path)) { - m_modelLoaded = false; - return std::async(std::launch::deferred, [] { return false; }); + std::shared_lock lock(m_mutex); + int index = m_modelNameToIndex[modelName]; + auto variant = getVariantLocked(index, getCurrentVariantForModel(modelName)); + if (!variant || !variant->isDownloaded) { + std::promise promise; + promise.set_value(false); + return promise.get_future(); } modelDir = std::filesystem::absolute( - variant->path.substr(0, variant->path.find_last_of("/\\")) - ).string(); + variant->path.substr(0, variant->path.find_last_of("/\\"))).string(); } - // Launch heavy loading in async task - return std::async(std::launch::async, [this, modelDir]() { - try { - bool success = m_inferenceEngine->loadModel(modelDir->c_str(), - ModelLoaderConfigManager::getInstance().getConfig()); 
+ return m_threadPool.enqueue([this, modelName, modelDir]() { + std::cout << "[ModelManager] size of inference engines: " << sizeof(m_inferenceEngines) << std::endl; - { - std::unique_lock lock(m_mutex); - m_modelLoaded = success; - } + auto engine = m_createInferenceEnginePtr(); + if (!engine) return false; - if (success) { - std::cout << "[ModelManager] Loaded model: " << *modelDir << "\n"; - } - return success; + bool success = engine->loadModel(modelDir->c_str(), ModelLoaderConfigManager::getInstance().getConfig()); + if (success) { + std::unique_lock lock(m_mutex); + m_inferenceEngines[modelName] = engine; + std::cout << "[ModelManager] size of inference engines: " << sizeof(m_inferenceEngines) << std::endl; + m_modelLoaded = true; } - catch (const std::exception& e) { - std::cerr << "[ModelManager] Load failed: " << e.what() << "\n"; - std::unique_lock lock(m_mutex); - m_modelLoaded = false; - return false; + else { + std::cerr << "Model load failed\n"; } + + return success; }); } - std::future ModelManager::unloadModelAsync() { + std::future ModelManager::unloadModelAsync(const std::string modelName) { // Capture current loaded state under lock bool isLoaded; { std::unique_lock lock(m_mutex); - isLoaded = m_modelLoaded; + // Check if the model is loaded in m_inferenceEngines + isLoaded = m_inferenceEngines.find(modelName) != m_inferenceEngines.end(); if (!isLoaded) { std::cerr << "[ModelManager] No model loaded to unload\n"; @@ -1769,9 +2346,12 @@ namespace Model } // Launch heavy unloading in async task - return std::async(std::launch::async, [this]() { + return std::async(std::launch::async, [this, modelName]() { try { - bool success = m_inferenceEngine->unloadModel(); + bool success = m_inferenceEngines.at(modelName)->unloadModel(); + // delete the engine instance + m_destroyInferenceEnginePtr(m_inferenceEngines.at(modelName)); + m_inferenceEngines.erase(modelName); { std::unique_lock lock(m_mutex); @@ -1797,15 +2377,19 @@ namespace Model void 
stopAllJobs() { - std::vector jobIdsCopy; + std::vector jobs; { - std::shared_lock lock(m_mutex); - jobIdsCopy = m_jobIds; + std::shared_lock lock(m_mutex); + jobs = m_jobIds; + + for (int id : jobs) { + m_activeJobs[id] = false; // Mark jobs as inactive + } } - for (auto jobId : jobIdsCopy) - { - stopJob(jobId); + for (int id : jobs) { + for (auto& [name, engine] : m_inferenceEngines) + engine->stopJob(id); } } @@ -1845,6 +2429,41 @@ namespace Model return response; } + static CompletionResponse convertToCompletionResponse(const CompletionRequest& request, const CompletionResult& result) { + CompletionResponse response; + response.model = request.model; + + // Create a choice with the generated text + CompletionChoice choice; + choice.index = 0; + choice.text = result.text; + choice.finish_reason = "stop"; // Assuming completion finished normally + + response.choices.push_back(choice); + + // Set usage statistics - this is an estimation + int promptLength = 0; + if (std::holds_alternative(request.prompt)) { + promptLength = std::get(request.prompt).size() / 4; // Rough token estimation + } + else if (std::holds_alternative>(request.prompt)) { + for (const auto& p : std::get>(request.prompt)) { + promptLength += p.size() / 4; + } + } + + int completionLength = result.text.size() / 4; // Rough token estimation + + response.usage.prompt_tokens = promptLength; + response.usage.completion_tokens = completionLength; + response.usage.total_tokens = promptLength + completionLength; + + return response; + } + + ThreadPool m_threadPool{ std::max(4u, std::thread::hardware_concurrency() - 1) }; + std::unordered_map> m_activeJobs; + mutable std::shared_mutex m_mutex; std::unique_ptr m_persistence; std::vector m_models; @@ -1855,14 +2474,16 @@ namespace Model std::vector> m_downloadFutures; std::future m_engineLoadFuture; std::future m_initializationFuture; + std::future m_persistenceFuture; std::vector> m_loadFutures; std::vector> m_unloadFutures; - std::atomic 
m_unloadInProgress{ false }; - std::atomic m_loadInProgress{ false }; + std::string m_unloadInProgress; + std::string m_loadInProgress; std::unordered_map m_modelVariantMap; std::atomic m_modelLoaded{ false }; std::atomic m_modelGenerationInProgress{ false }; std::vector m_jobIds; + bool m_isVulkanBackend{ false }; #ifdef _WIN32 HMODULE m_inferenceLibHandle = nullptr; @@ -1871,10 +2492,10 @@ namespace Model CreateInferenceEngineFunc* m_createInferenceEnginePtr = nullptr; DestroyInferenceEngineFunc* m_destroyInferenceEnginePtr = nullptr; - IInferenceEngine* m_inferenceEngine = nullptr; + std::map m_inferenceEngines; // Server related - struct StreamingContext { + struct ChatCompletionStreamingContext { std::mutex mtx; std::condition_variable cv; std::vector chunks; @@ -1885,8 +2506,23 @@ namespace Model bool error = false; }; std::mutex m_streamContextsMutex; - std::unordered_map> + std::unordered_map> m_streamingContexts; + + struct CompletionStreamingContext { + std::mutex mtx; + std::condition_variable cv; + std::string model; + int jobId = -1; + std::vector chunks; + bool finished = false; + bool error = false; + std::string errorMessage; + std::string fullText; // Accumulated full text + }; + std::mutex m_completionStreamContextsMutex; + std::unordered_map> + m_completionStreamingContexts; }; inline void initializeModelManager(const bool async = true) diff --git a/include/system_monitor.hpp b/include/system_monitor.hpp new file mode 100644 index 0000000..f544b70 --- /dev/null +++ b/include/system_monitor.hpp @@ -0,0 +1,399 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#include // For IDXGIAdapter3 and QueryVideoMemoryInfo +#pragma comment(lib, "dxgi.lib") +#else +#include +#include +#include +#ifdef __APPLE__ +#include +#include +#include +#include +#include +#endif +#endif + +constexpr size_t GB = 1024 * 1024 * 1024; + +class SystemMonitor { +public: + static 
SystemMonitor& getInstance() { + static SystemMonitor instance; + return instance; + } + + // CPU/Memory statistics + size_t getTotalSystemMemory() { + return m_totalMemory; + } + size_t getAvailableSystemMemory() { + return m_availableMemory; + } + size_t getUsedMemoryByProcess() { + return m_usedMemory; + } + float getCpuUsagePercentage() { + return m_cpuUsage; + } + + // GPU Memory statistics using DirectX (global memory, not per process) + bool hasGpuSupport() const { return m_gpuMonitoringSupported; } + size_t getTotalGpuMemory() { + if (!m_gpuMonitoringSupported) return 0; + return m_totalGpuMemory; + } + size_t getAvailableGpuMemory() { + if (!m_gpuMonitoringSupported) return 0; + return m_availableGpuMemory; + } + size_t getUsedGpuMemoryByProcess() { + if (!m_gpuMonitoringSupported) return 0; + return m_usedGpuMemory; + } + + // Initialize GPU monitoring with DirectX backend (Windows only) + void initializeGpuMonitoring() { +#ifdef _WIN32 + std::lock_guard lock(m_gpuMutex); + initializeDirectX(); +#else + m_gpuMonitoringSupported = false; +#endif + } + + // Calculate if there's enough memory to load a model + bool hasEnoughMemoryForModel(size_t modelSizeBytes, size_t kvCacheSizeBytes) { + // Update stats to get the latest values + update(); + + // Calculate total required memory + size_t totalRequiredMemory = modelSizeBytes + kvCacheSizeBytes; + + // Add 20% overhead for safety margin + totalRequiredMemory = static_cast(totalRequiredMemory * 1.2); + + if (m_gpuMonitoringSupported) { + // Check if GPU has enough available memory + if (m_availableGpuMemory < totalRequiredMemory) { + return false; + } + return true; + } + else { + // Check if system RAM has enough memory (threshold of 2GB more) + if (m_availableMemory + 2 * GB < totalRequiredMemory) { + return false; + } + return true; + } + } + + // Update monitoring state - call periodically + void update() { + auto currentTime = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast( 
+ currentTime - m_lastCpuMeasurement).count(); + + // Only update every 1000ms to avoid excessive CPU usage + if (elapsed >= 1000) { + updateCpuUsage(); + updateMemoryStats(); + + if (m_gpuMonitoringSupported) { + updateGpuStats(); + } + + m_lastCpuMeasurement = currentTime; + } + } + +private: + SystemMonitor() : m_lastCpuMeasurement(std::chrono::steady_clock::now()) + { +#ifdef _WIN32 + ZeroMemory(&m_prevSysKernelTime, sizeof(FILETIME)); + ZeroMemory(&m_prevSysUserTime, sizeof(FILETIME)); + ZeroMemory(&m_prevProcKernelTime, sizeof(FILETIME)); + ZeroMemory(&m_prevProcUserTime, sizeof(FILETIME)); +#else + m_prevTotalUser = 0; + m_prevTotalUserLow = 0; + m_prevTotalSys = 0; + m_prevTotalIdle = 0; + m_prevProcessTotalUser = 0; + m_prevProcessTotalSys = 0; +#endif + + // Initialize memory stats and CPU usage tracking + updateMemoryStats(); + updateCpuUsage(); + } + ~SystemMonitor() { +#ifdef _WIN32 + if (m_dxgiAdapter) { + m_dxgiAdapter->Release(); + m_dxgiAdapter = nullptr; + } + if (m_dxgiFactory) { + m_dxgiFactory->Release(); + m_dxgiFactory = nullptr; + } +#endif + } + + // CPU monitoring members + std::atomic m_cpuUsage{ 0.0f }; + std::atomic m_usedMemory{ 0 }; + std::atomic m_availableMemory{ 0 }; + std::atomic m_totalMemory{ 0 }; + std::chrono::steady_clock::time_point m_lastCpuMeasurement; + std::mutex m_cpuMutex; + +#ifdef _WIN32 + FILETIME m_prevSysKernelTime; + FILETIME m_prevSysUserTime; + FILETIME m_prevProcKernelTime; + FILETIME m_prevProcUserTime; +#else + unsigned long long m_prevTotalUser; + unsigned long long m_prevTotalUserLow; + unsigned long long m_prevTotalSys; + unsigned long long m_prevTotalIdle; + unsigned long long m_prevProcessTotalUser; + unsigned long long m_prevProcessTotalSys; +#endif + + // GPU monitoring members + bool m_gpuMonitoringSupported{ false }; + std::atomic m_totalGpuMemory{ 0 }; + std::atomic m_availableGpuMemory{ 0 }; + std::atomic m_usedGpuMemory{ 0 }; + std::mutex m_gpuMutex; + +#ifdef _WIN32 + // DirectX-specific members 
+ IDXGIFactory1* m_dxgiFactory{ nullptr }; + IDXGIAdapter3* m_dxgiAdapter{ nullptr }; +#endif + + // Private helper methods + + void updateCpuUsage() { +#ifdef _WIN32 + FILETIME sysIdleTime, sysKernelTime, sysUserTime; + FILETIME procCreationTime, procExitTime, procKernelTime, procUserTime; + + // Get system times + if (!GetSystemTimes(&sysIdleTime, &sysKernelTime, &sysUserTime)) { + return; + } + + // Get process times + HANDLE hProcess = GetCurrentProcess(); + if (!GetProcessTimes(hProcess, &procCreationTime, &procExitTime, &procKernelTime, &procUserTime)) { + return; + } + + // First call - just store previous times and return + if (m_prevSysKernelTime.dwLowDateTime == 0 && m_prevSysKernelTime.dwHighDateTime == 0) { + m_prevSysKernelTime = sysKernelTime; + m_prevSysUserTime = sysUserTime; + m_prevProcKernelTime = procKernelTime; + m_prevProcUserTime = procUserTime; + return; + } + + // Convert FILETIME to ULARGE_INTEGER for arithmetic + ULARGE_INTEGER sysKernelTimeULI, sysUserTimeULI; + ULARGE_INTEGER procKernelTimeULI, procUserTimeULI; + ULARGE_INTEGER prevSysKernelTimeULI, prevSysUserTimeULI; + ULARGE_INTEGER prevProcKernelTimeULI, prevProcUserTimeULI; + + sysKernelTimeULI.LowPart = sysKernelTime.dwLowDateTime; + sysKernelTimeULI.HighPart = sysKernelTime.dwHighDateTime; + sysUserTimeULI.LowPart = sysUserTime.dwLowDateTime; + sysUserTimeULI.HighPart = sysUserTime.dwHighDateTime; + + procKernelTimeULI.LowPart = procKernelTime.dwLowDateTime; + procKernelTimeULI.HighPart = procKernelTime.dwHighDateTime; + procUserTimeULI.LowPart = procUserTime.dwLowDateTime; + procUserTimeULI.HighPart = procUserTime.dwHighDateTime; + + prevSysKernelTimeULI.LowPart = m_prevSysKernelTime.dwLowDateTime; + prevSysKernelTimeULI.HighPart = m_prevSysKernelTime.dwHighDateTime; + prevSysUserTimeULI.LowPart = m_prevSysUserTime.dwLowDateTime; + prevSysUserTimeULI.HighPart = m_prevSysUserTime.dwHighDateTime; + + prevProcKernelTimeULI.LowPart = m_prevProcKernelTime.dwLowDateTime; + 
prevProcKernelTimeULI.HighPart = m_prevProcKernelTime.dwHighDateTime; + prevProcUserTimeULI.LowPart = m_prevProcUserTime.dwLowDateTime; + prevProcUserTimeULI.HighPart = m_prevProcUserTime.dwHighDateTime; + + // Calculate time differences + ULONGLONG sysTimeChange = (sysKernelTimeULI.QuadPart - prevSysKernelTimeULI.QuadPart) + + (sysUserTimeULI.QuadPart - prevSysUserTimeULI.QuadPart); + + ULONGLONG procTimeChange = (procKernelTimeULI.QuadPart - prevProcKernelTimeULI.QuadPart) + + (procUserTimeULI.QuadPart - prevProcUserTimeULI.QuadPart); + + // Calculate CPU usage percentage for the process + if (sysTimeChange > 0) { + m_cpuUsage = (float)((100.0 * procTimeChange) / sysTimeChange); + if (m_cpuUsage > 100.0f) m_cpuUsage = 100.0f; + } + + // Store current times for next measurement + m_prevSysKernelTime = sysKernelTime; + m_prevSysUserTime = sysUserTime; + m_prevProcKernelTime = procKernelTime; + m_prevProcUserTime = procUserTime; +#else + m_cpuUsage = 0.0f; +#endif + } + + void updateMemoryStats() { +#ifdef _WIN32 + MEMORYSTATUSEX memInfo; + memInfo.dwLength = sizeof(MEMORYSTATUSEX); + GlobalMemoryStatusEx(&memInfo); + m_totalMemory = memInfo.ullTotalPhys; + m_availableMemory = memInfo.ullAvailPhys; + + PROCESS_MEMORY_COUNTERS_EX pmc; + if (GetProcessMemoryInfo(GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS*)&pmc, sizeof(pmc))) { + m_usedMemory = pmc.PrivateUsage; + } +#elif defined(__APPLE__) + mach_port_t host_port = mach_host_self(); + vm_size_t page_size; + host_page_size(host_port, &page_size); + + vm_statistics64_data_t vm_stats; + mach_msg_type_number_t count = HOST_VM_INFO64_COUNT; + if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stats, &count) == KERN_SUCCESS) { + m_availableMemory = (vm_stats.free_count + vm_stats.inactive_count) * page_size; + } + + int mib[2] = { CTL_HW, HW_MEMSIZE }; + uint64_t total_memory = 0; + size_t len = sizeof(total_memory); + if (sysctl(mib, 2, &total_memory, &len, NULL, 0) == 0) { + m_totalMemory = 
total_memory; + } + + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage) == 0) { + m_usedMemory = usage.ru_maxrss * 1024; + } +#else + struct sysinfo memInfo; + if (sysinfo(&memInfo) == 0) { + m_totalMemory = memInfo.totalram * memInfo.mem_unit; + m_availableMemory = memInfo.freeram * memInfo.mem_unit; + } + FILE* fp = fopen("/proc/self/statm", "r"); + if (fp) { + unsigned long vm = 0, rss = 0; + if (fscanf(fp, "%lu %lu", &vm, &rss) == 2) { + m_usedMemory = rss * sysconf(_SC_PAGESIZE); + } + fclose(fp); + } +#endif + } + + void updateGpuStats() { +#ifdef _WIN32 + if (m_gpuMonitoringSupported) { + updateDirectXGpuStats(); + } +#else + m_totalGpuMemory = 0; + m_availableGpuMemory = 0; + m_usedGpuMemory = 0; +#endif + } + +#ifdef _WIN32 + // DirectX (DXGI) GPU monitoring methods + + void initializeDirectX() { + HRESULT hr = CreateDXGIFactory1(__uuidof(IDXGIFactory1), reinterpret_cast(&m_dxgiFactory)); + if (FAILED(hr)) { + std::cerr << "[SystemMonitor] Failed to create DXGI Factory" << std::endl; + m_gpuMonitoringSupported = false; + return; + } + + // Enumerate adapters and choose a dedicated one (NVIDIA or AMD) + IDXGIAdapter* adapter = nullptr; + IDXGIAdapter3* dedicatedAdapter = nullptr; + for (UINT i = 0; m_dxgiFactory->EnumAdapters(i, &adapter) != DXGI_ERROR_NOT_FOUND; i++) { + IDXGIAdapter3* adapter3 = nullptr; + hr = adapter->QueryInterface(__uuidof(IDXGIAdapter3), reinterpret_cast(&adapter3)); + if (SUCCEEDED(hr) && adapter3) { + DXGI_ADAPTER_DESC desc; + hr = adapter3->GetDesc(&desc); + if (SUCCEEDED(hr)) { + // Check for NVIDIA (0x10DE) or AMD (0x1002) + if (desc.VendorId == 0x10DE || desc.VendorId == 0x1002) { + dedicatedAdapter = adapter3; + adapter->Release(); + break; + } + } + adapter3->Release(); + } + adapter->Release(); + } + + if (!dedicatedAdapter) { + std::cerr << "[SystemMonitor] No dedicated NVIDIA/AMD GPU found." 
<< std::endl; + m_gpuMonitoringSupported = false; + return; + } + + m_dxgiAdapter = dedicatedAdapter; + m_gpuMonitoringSupported = true; + updateDirectXGpuStats(); + } + + void updateDirectXGpuStats() { + if (!m_dxgiAdapter) + return; + + DXGI_QUERY_VIDEO_MEMORY_INFO videoMemoryInfo = {}; + HRESULT hr = m_dxgiAdapter->QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &videoMemoryInfo); + if (SUCCEEDED(hr)) { + m_usedGpuMemory = videoMemoryInfo.CurrentUsage; + + DXGI_ADAPTER_DESC adapterDesc = {}; + hr = m_dxgiAdapter->GetDesc(&adapterDesc); + if (SUCCEEDED(hr)) { + m_totalGpuMemory = static_cast(adapterDesc.DedicatedVideoMemory); + } + else { + m_totalGpuMemory = videoMemoryInfo.Budget; + } + + m_availableGpuMemory = (m_totalGpuMemory > m_usedGpuMemory) ? + m_totalGpuMemory - m_usedGpuMemory : 0; + } + } +#endif +}; diff --git a/include/threadpool.hpp b/include/threadpool.hpp new file mode 100644 index 0000000..db89713 --- /dev/null +++ b/include/threadpool.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t numThreads = std::thread::hardware_concurrency()) { + m_workers.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + m_workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(m_queueMutex); + m_condition.wait(lock, [this] { + return m_stop || !m_tasks.empty(); + }); + + if (m_stop && m_tasks.empty()) return; + + task = std::move(m_tasks.front()); + m_tasks.pop(); + } + task(); + } + }); + } + } + + template + auto enqueue(F&& f, Args&&... args) -> std::future::type> { + using return_type = typename std::invoke_result::type; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) 
+ ); + + std::future res = task->get_future(); + { + std::unique_lock lock(m_queueMutex); + if (m_stop) throw std::runtime_error("enqueue on stopped ThreadPool"); + + m_tasks.emplace([task]() { (*task)(); }); + } + m_condition.notify_one(); + return res; + } + + ~ThreadPool() { + { + std::unique_lock lock(m_queueMutex); + m_stop = true; + } + m_condition.notify_all(); + for (std::thread& worker : m_workers) { + if (worker.joinable()) { + worker.join(); + } + } + } + +private: + std::vector m_workers; + std::queue> m_tasks; + + std::mutex m_queueMutex; + std::condition_variable m_condition; + bool m_stop = false; +}; \ No newline at end of file diff --git a/include/ui/chat/chat_history.hpp b/include/ui/chat/chat_history.hpp index 63688e5..4134136 100644 --- a/include/ui/chat/chat_history.hpp +++ b/include/ui/chat/chat_history.hpp @@ -50,7 +50,7 @@ class ChatHistoryRenderer { bubbleBgColorAssistant = ImVec4(0.0f, 0.0f, 0.0f, 0.0f); } - void render(const Chat::ChatHistory& chatHistory, float contentWidth) + void render(const Chat::ChatHistory& chatHistory, float contentWidth, float& paddingX) { const size_t currentMessageCount = chatHistory.messages.size(); const bool newMessageAdded = currentMessageCount > m_lastMessageCount; @@ -60,7 +60,7 @@ class ChatHistoryRenderer { const bool atBottom = (scrollMaxY <= 0.0f) || (scrollY >= scrollMaxY - ChatHistoryConstants::MIN_SCROLL_DIFFERENCE); for (size_t i = 0; i < currentMessageCount; ++i) { - renderMessage(chatHistory.messages[i], static_cast(i), contentWidth); + renderMessage(chatHistory.messages[i], static_cast(i), contentWidth, paddingX); } if (newMessageAdded && atBottom) { @@ -130,7 +130,7 @@ class ChatHistoryRenderer { return dim; } - void renderMessageContent(const Chat::Message& msg, float bubbleWidth, float bubblePadding) + void renderMessageContent(const Chat::Message& msg, float bubbleWidth, float bubblePadding, float& paddingX) { if (msg.role == "user") { ImGui::SetCursorPosX(bubblePadding); @@ -139,6 +139,8 
@@ class ChatHistoryRenderer { } ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 24); + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + paddingX); + ImGui::BeginChild("##assistant_message_" + msg.id, { bubbleWidth, 0 }, ImGuiChildFlags_AutoResizeY); ImGui::BeginGroup(); auto segments = parseThinkSegments(msg.content); @@ -196,6 +198,7 @@ class ChatHistoryRenderer { } ImGui::EndGroup(); + ImGui::EndChild(); } static void chatStreamingCallback(const std::string& partialOutput, const float tps, const int jobId, const bool isFinished) { @@ -240,7 +243,7 @@ class ChatHistoryRenderer { // Stop current generation if running. if (modelManager.isCurrentlyGenerating()) { - modelManager.stopJob(chatManager.getCurrentJobId()); + modelManager.stopJob(chatManager.getCurrentJobId(), modelManager.getCurrentModelName().value()); while (true) { @@ -312,7 +315,8 @@ class ChatHistoryRenderer { chatManager.getCurrentChat().value() ); - int jobId = modelManager.startChatCompletionJob(completionParams, chatStreamingCallback); + int jobId = modelManager.startChatCompletionJob(completionParams, chatStreamingCallback, + modelManager.getCurrentModelName().value()); if (!chatManager.setCurrentJobId(jobId)) { std::cerr << "[ChatSection] Failed to set the current job ID.\n"; } @@ -320,9 +324,13 @@ class ChatHistoryRenderer { modelManager.setModelGenerationInProgress(true); } - void renderMetadata(const Chat::Message& msg, int index, float bubbleWidth, float bubblePadding) + void renderMetadata(const Chat::Message& msg, int index, float bubbleWidth, float bubblePadding, float& paddingX) { ImGui::PushStyleColor(ImGuiCol_Text, timestampColor); + if (msg.role == "assistant") + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + paddingX); + + float cursorX = ImGui::GetCursorPosX(); // Timestamp ImGui::TextWrapped("%s", timePointToString(msg.timestamp).c_str()); @@ -337,9 +345,8 @@ class ChatHistoryRenderer { // Copy button ImGui::SameLine(); ImGui::SetCursorPosX( - ImGui::GetCursorPosX() + 
ImGui::GetContentRegionAvail().x - - (msg.role == "assistant" ? 2 : 1) * - Config::Button::WIDTH - bubblePadding + cursorX + bubbleWidth - + 2 * Config::Button::WIDTH - bubblePadding ); std::vector helperButtons; @@ -375,7 +382,7 @@ class ChatHistoryRenderer { ImGui::PopStyleColor(); } - void renderMessage(const Chat::Message& msg, int index, float contentWidth) + void renderMessage(const Chat::Message& msg, int index, float contentWidth, float& _paddingX /* Padding to center the message */) { const auto [bubbleWidth, bubblePadding, paddingX] = calculateDimensions(msg, contentWidth); @@ -384,7 +391,7 @@ class ChatHistoryRenderer { ? bubbleBgColorUser : bubbleBgColorAssistant); - ImGui::SetCursorPosX(paddingX); + ImGui::SetCursorPosX(paddingX + _paddingX); if (msg.role == "user") { ImVec2 textSize = ImGui::CalcTextSize(msg.content.c_str(), nullptr, true, bubbleWidth - 2 * bubblePadding); @@ -396,9 +403,9 @@ class ChatHistoryRenderer { ImGuiChildFlags_Border | ImGuiChildFlags_AlwaysUseWindowPadding); ImGui::PopStyleVar(); - renderMessageContent(msg, bubbleWidth - 2 * bubblePadding, bubblePadding); + renderMessageContent(msg, bubbleWidth - 2 * bubblePadding, bubblePadding, _paddingX); ImGui::Spacing(); - renderMetadata(msg, index, bubbleWidth, 0); + renderMetadata(msg, index, bubbleWidth, 0, _paddingX); ImGui::EndChild(); } @@ -421,9 +428,9 @@ class ChatHistoryRenderer { ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 12); } - renderMessageContent(msg, bubbleWidth, bubblePadding); + renderMessageContent(msg, bubbleWidth, bubblePadding, _paddingX); ImGui::Spacing(); - renderMetadata(msg, index, bubbleWidth, bubblePadding); + renderMetadata(msg, index, bubbleWidth, bubblePadding, _paddingX); } ImGui::PopStyleColor(); diff --git a/include/ui/chat/chat_history_sidebar.hpp b/include/ui/chat/chat_history_sidebar.hpp index 54eb525..d58e23a 100644 --- a/include/ui/chat/chat_history_sidebar.hpp +++ b/include/ui/chat/chat_history_sidebar.hpp @@ -123,7 +123,7 @@ class 
ChatHistorySidebar { void render() { ImGuiIO& io = ImGui::GetIO(); - const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT; + const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT - Config::FOOTER_HEIGHT; // Set up the sidebar window. ImGui::SetNextWindowPos(ImVec2(0, Config::TITLE_BAR_HEIGHT), ImGuiCond_Always); diff --git a/include/ui/chat/chat_window.hpp b/include/ui/chat/chat_window.hpp index c89f80a..9936eaf 100644 --- a/include/ui/chat/chat_window.hpp +++ b/include/ui/chat/chat_window.hpp @@ -5,7 +5,7 @@ #include "chat_history.hpp" #include "ui/widgets.hpp" #include "ui/markdown.hpp" -#include "ui/chat/model_manager_modal.hpp" +#include "ui/model_manager_modal.hpp" #include "chat/chat_manager.hpp" #include "model/preset_manager.hpp" #include "model/model_manager.hpp" @@ -187,14 +187,17 @@ class ChatWindow { sendButtonConfig.tooltip = "Send Message"; inputPlaceholderText = "Type a message and press Enter to send (Ctrl+Enter or Shift+Enter for new line)"; + + // Initialize auto-scroll state + m_shouldAutoScroll = true; + m_wasAtBottom = true; + m_lastContentHeight = 0.0f; } - // Render the chat window. This method computes layout values and then renders - // the cached widgets, updating only the dynamic properties. void render(float leftSidebarWidth, float rightSidebarWidth) { ImGuiIO& io = ImGui::GetIO(); ImVec2 windowSize = ImVec2(io.DisplaySize.x - rightSidebarWidth - leftSidebarWidth, - io.DisplaySize.y - Config::TITLE_BAR_HEIGHT); + io.DisplaySize.y - Config::TITLE_BAR_HEIGHT - Config::FOOTER_HEIGHT); ImGui::SetNextWindowPos(ImVec2(leftSidebarWidth, Config::TITLE_BAR_HEIGHT), ImGuiCond_Always); ImGui::SetNextWindowSize(windowSize, ImGuiCond_Always); @@ -221,22 +224,16 @@ class ChatWindow { // Render the clear chat modal. clearChatModal.render(); - // Render the rename chat modal. + // Render the rename chat modal. renameChatModal.render(); // Spacing between widgets. 
for (int i = 0; i < 4; ++i) ImGui::Spacing(); - if (paddingX > 0.0F) - ImGui::SetCursorPosX(ImGui::GetCursorPosX() + paddingX); - // Render the chat history region. float availableHeight = ImGui::GetContentRegionAvail().y - m_inputHeight - Config::BOTTOM_MARGIN; - ImGui::BeginChild("ChatHistoryRegion", ImVec2(contentWidth, availableHeight), false, ImGuiWindowFlags_NoScrollbar); - if (auto chat = Chat::ChatManager::getInstance().getCurrentChat()) - chatHistoryRenderer.render(*chat, contentWidth); - ImGui::EndChild(); + renderChatHistoryWithAutoScroll(contentWidth, availableHeight, paddingX); ImGui::Spacing(); float inputFieldPaddingX = (availableWidth - contentWidth) / 2.0F; @@ -249,6 +246,53 @@ class ChatWindow { } private: + void renderChatHistoryWithAutoScroll(float contentWidth, float availableHeight, float paddingX) { + const char* chatHistoryId = "ChatHistoryRegion"; + + // Begin the child window for chat history + ImGui::PushStyleColor(ImGuiCol_ScrollbarBg, ImVec4(0, 0, 0, 0)); + ImGui::BeginChild(chatHistoryId, ImVec2(-1, availableHeight), false); + ImGui::PopStyleColor(); + + // Check if we were at the bottom before rendering new content + float scrollY = ImGui::GetScrollY(); + float maxScrollY = ImGui::GetScrollMaxY(); + m_wasAtBottom = (maxScrollY <= 0.0f) || (scrollY >= maxScrollY - 1.0f); + + // Render the chat content + if (auto chat = Chat::ChatManager::getInstance().getCurrentChat()) + { + chatHistoryRenderer.render(*chat, contentWidth, paddingX); + } + + // Calculate if content height has changed (new content was added) + float currentMaxScrollY = ImGui::GetScrollMaxY(); + bool contentHeightChanged = currentMaxScrollY != m_lastContentHeight; + m_lastContentHeight = currentMaxScrollY; + + if (ImGui::IsMouseDragging(ImGuiMouseButton_Left) || ImGui::GetIO().MouseWheel != 0) { + // User is actively scrolling, check if they're at the bottom + float newScrollY = ImGui::GetScrollY(); + float newMaxScrollY = ImGui::GetScrollMaxY(); + bool nowAtBottom = 
(newMaxScrollY <= 0.0f) || (newScrollY >= newMaxScrollY - 1.0f); + + // Update auto-scroll state based on whether user is at bottom + m_shouldAutoScroll = nowAtBottom; + } + + bool shouldScrollToBottom = m_shouldAutoScroll && (m_wasAtBottom || contentHeightChanged); + + // Check if we're generating content + bool isGenerating = Model::ModelManager::getInstance().isCurrentlyGenerating(); + + // Force scroll to bottom if generating and should auto-scroll + if (shouldScrollToBottom || (isGenerating && m_shouldAutoScroll)) { + ImGui::SetScrollHereY(1.0f); // 1.0f means align to bottom + } + + ImGui::EndChild(); + } + static void chatStreamingCallback(const std::string& partialOutput, const float tps, const int jobId, const bool isFinished) { auto& chatManager = Chat::ChatManager::getInstance(); auto& modelManager = Model::ModelManager::getInstance(); @@ -305,7 +349,7 @@ class ChatWindow { auto& chatManager = Chat::ChatManager::getInstance(); // Generate the title (synchronous call) - CompletionResult titleResult = modelManager.chatCompleteSync(titleParams, false); + CompletionResult titleResult = modelManager.chatCompleteSync(titleParams, modelManager.getCurrentModelName().value(), false); if (!titleResult.text.empty()) { // Clean up the generated title @@ -321,6 +365,23 @@ class ChatWindow { s.erase(pos, titlePrefix.length()); } + // Remove any thinking tags (e.g., ...) 
+ const std::string thinkStart = ""; + const std::string thinkEnd = ""; + size_t startPos = s.find(thinkStart); + while (startPos != std::string::npos) { + size_t endPos = s.find(thinkEnd, startPos + thinkStart.length()); + if (endPos != std::string::npos) { + s.erase(startPos, endPos + thinkEnd.length() - startPos); + } + else { + // If no matching end tag, remove from start tag to end of string + s.erase(startPos); + break; + } + startPos = s.find(thinkStart, startPos); + } + // Trim whitespace s.erase(0, s.find_first_not_of(" \t\n\r")); if (!s.empty()) { @@ -334,7 +395,7 @@ class ChatWindow { if (!newTitle.empty()) { if (!chatManager.renameCurrentChat(newTitle).get()) { - std::cerr << "[ChatSection] Failed to rename chat to: " << newTitle << "\n"; + std::cerr << "[ChatSection] Failed to rename chat to: " << newTitle << "\n"; } } } @@ -343,13 +404,13 @@ class ChatWindow { // Render the row of buttons that allow the user to switch models or clear chat. void renderChatFeatureButtons(float baseX, float baseY) { - Model::ModelManager& modelManager = Model::ModelManager::getInstance(); + Model::ModelManager& modelManager = Model::ModelManager::getInstance(); // Update the open-model manager button’s label dynamically. 
openModelManagerConfig.label = modelManager.getCurrentModelName().value_or("Select Model"); - openModelManagerConfig.tooltip = - modelManager.getCurrentModelName().value_or("Select Model"); + openModelManagerConfig.tooltip = + modelManager.getCurrentModelName().value_or("Select Model"); if (modelManager.isLoadInProgress()) { @@ -358,7 +419,7 @@ class ChatWindow { if (modelManager.isModelLoaded()) { - openModelManagerConfig.icon = ICON_CI_SPARKLE_FILLED; + openModelManagerConfig.icon = ICON_CI_SPARKLE_FILLED; } std::vector buttons = { openModelManagerConfig, clearChatButtonConfig }; @@ -398,13 +459,17 @@ class ChatWindow { buildChatCompletionParameters(currentChat, message); auto& modelManager = Model::ModelManager::getInstance(); - int jobId = modelManager.startChatCompletionJob(completionParams, chatStreamingCallback); + int jobId = modelManager.startChatCompletionJob(completionParams, chatStreamingCallback, + modelManager.getCurrentModelName().value()); if (!chatManager.setCurrentJobId(jobId)) { std::cerr << "[ChatSection] Failed to set the current job ID.\n"; } modelManager.setModelGenerationInProgress(true); + // Ensure auto-scroll is enabled when starting a new generation + m_shouldAutoScroll = true; + // If this is the first message, generate a title for the chat if (isFirstMessage) { generateChatTitle(message); @@ -462,20 +527,21 @@ class ChatWindow { sendButtonConfig.tooltip = "Stop generation"; sendButtonConfig.onClick = []() { Model::ModelManager::getInstance().stopJob( - Chat::ChatManager::getInstance().getCurrentJobId() + Chat::ChatManager::getInstance().getCurrentJobId(), + Model::ModelManager::getInstance().getCurrentModelName().value() ); }; sendButtonConfig.state = ButtonState::NORMAL; } - // Disable the send button and input processing if no model is loaded. - if (!modelManager.isModelLoaded()) { + // Disable the send button and input processing if no model is loaded. 
+ if (!modelManager.isModelLoaded()) { inputConfig.flags = ImGuiInputTextFlags_CtrlEnterForNewLine | ImGuiInputTextFlags_ShiftEnterForNewLine; inputConfig.processInput = nullptr; - sendButtonConfig.state = ButtonState::DISABLED; - } + sendButtonConfig.state = ButtonState::DISABLED; + } } void drawInputFieldBackground(const float width, const float height) { @@ -533,6 +599,11 @@ class ChatWindow { bool focusInputField = true; float m_inputHeight = Config::INPUT_HEIGHT; + // Auto-scroll state variables + bool m_shouldAutoScroll; + bool m_wasAtBottom; + float m_lastContentHeight; + // Child components. ModelManagerModal modelManagerModal; RenameChatModalComponent renameChatModal; diff --git a/include/ui/chat/model_manager_modal.hpp b/include/ui/chat/model_manager_modal.hpp deleted file mode 100644 index 8aabaac..0000000 --- a/include/ui/chat/model_manager_modal.hpp +++ /dev/null @@ -1,686 +0,0 @@ -#pragma once - -#include "imgui.h" -#include "ui/widgets.hpp" -#include "ui/markdown.hpp" -#include "model/model_manager.hpp" -#include "ui/fonts.hpp" -#include -#include -#include -#include - -namespace ModelManagerConstants { - constexpr float cardWidth = 200.0f; - constexpr float cardHeight = 220.0f; - constexpr float cardSpacing = 10.0f; - constexpr float padding = 16.0f; - constexpr float modalVerticalScale = 0.9f; - constexpr float sectionSpacing = 20.0f; - constexpr float sectionHeaderHeight = 30.0f; -} - -class DeleteModelModalComponent { -public: - DeleteModelModalComponent() { - ButtonConfig cancelButton; - cancelButton.id = "##cancelDeleteModel"; - cancelButton.label = "Cancel"; - cancelButton.backgroundColor = RGBAToImVec4(34, 34, 34, 255); - cancelButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); - cancelButton.activeColor = RGBAToImVec4(26, 95, 180, 255); - cancelButton.textColor = RGBAToImVec4(255, 255, 255, 255); - cancelButton.size = ImVec2(130, 0); - cancelButton.onClick = []() { ImGui::CloseCurrentPopup(); }; - - ButtonConfig confirmButton; - 
confirmButton.id = "##confirmDeleteModel"; - confirmButton.label = "Confirm"; - confirmButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); - confirmButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); - confirmButton.activeColor = RGBAToImVec4(26, 95, 180, 255); - confirmButton.size = ImVec2(130, 0); - confirmButton.onClick = [this]() { - if (m_index != -1 && !m_variant.empty()) { - Model::ModelManager::getInstance().deleteDownloadedModel(m_index, m_variant); - ImGui::CloseCurrentPopup(); - } - }; - - buttons.push_back(cancelButton); - buttons.push_back(confirmButton); - } - - void setModel(int index, const std::string& variant) { - m_index = index; - m_variant = variant; - } - - void render(bool& openModal) { - if (m_index == -1 || m_variant.empty()) { - openModal = false; - return; - } - - ModalConfig config{ - "Confirm Delete Model", - "Confirm Delete Model", - ImVec2(300, 96), - [this]() { - Button::renderGroup(buttons, 16, ImGui::GetCursorPosY() + 8); - }, - openModal - }; - config.padding = ImVec2(16.0f, 8.0f); - ModalWindow::render(config); - - if (!ImGui::IsPopupOpen(config.id.c_str())) { - openModal = false; - m_index = -1; - m_variant.clear(); - } - } - -private: - int m_index = -1; - std::string m_variant; - std::vector buttons; -}; - -class ModelCardRenderer { -public: - ModelCardRenderer(int index, const Model::ModelData& modelData, - std::function onDeleteRequested, std::string id = "") - : m_index(index), m_model(modelData), m_onDeleteRequested(onDeleteRequested), m_id(id) - { - selectButton.id = "##select" + std::to_string(m_index) + m_id; - selectButton.size = ImVec2(ModelManagerConstants::cardWidth - 18, 0); - - deleteButton.id = "##delete" + std::to_string(m_index) + m_id; - deleteButton.size = ImVec2(24, 0); - deleteButton.backgroundColor = RGBAToImVec4(200, 50, 50, 255); - deleteButton.hoverColor = RGBAToImVec4(220, 70, 70, 255); - deleteButton.activeColor = RGBAToImVec4(200, 50, 50, 255); - deleteButton.icon = ICON_CI_TRASH; - 
deleteButton.onClick = [this]() { - std::string currentVariant = Model::ModelManager::getInstance().getCurrentVariantForModel(m_model.name); - m_onDeleteRequested(m_index, currentVariant); - }; - - authorLabel.id = "##modelAuthor" + std::to_string(m_index) + m_id; - authorLabel.label = m_model.author; - authorLabel.size = ImVec2(0, 0); - authorLabel.fontType = FontsManager::ITALIC; - authorLabel.fontSize = FontsManager::SM; - authorLabel.alignment = Alignment::LEFT; - - nameLabel.id = "##modelName" + std::to_string(m_index) + m_id; - nameLabel.label = m_model.name; - nameLabel.size = ImVec2(0, 0); - nameLabel.fontType = FontsManager::BOLD; - nameLabel.fontSize = FontsManager::MD; - nameLabel.alignment = Alignment::LEFT; - } - - void render() { - auto& manager = Model::ModelManager::getInstance(); - std::string currentVariant = manager.getCurrentVariantForModel(m_model.name); - - ImGui::BeginGroup(); - ImGui::PushStyleColor(ImGuiCol_ChildBg, RGBAToImVec4(26, 26, 26, 255)); - ImGui::PushStyleVar(ImGuiStyleVar_ChildRounding, 8.0f); - - std::string childName = "ModelCard" + std::to_string(m_index) + m_id; - ImGui::BeginChild(childName.c_str(), ImVec2(ModelManagerConstants::cardWidth, ModelManagerConstants::cardHeight), true); - - renderHeader(); - ImGui::Spacing(); - renderVariantOptions(currentVariant); - - ImGui::SetCursorPosY(ModelManagerConstants::cardHeight - 35); - - bool isSelected = (m_model.name == manager.getCurrentModelName() && - currentVariant == manager.getCurrentVariantType()); - bool isDownloaded = manager.isModelDownloaded(m_index, currentVariant); - - if (!isDownloaded) { - double progress = manager.getModelDownloadProgress(m_index, currentVariant); - if (progress > 0.0) { - selectButton.label = "Cancel"; - selectButton.backgroundColor = RGBAToImVec4(200, 50, 50, 255); - selectButton.hoverColor = RGBAToImVec4(220, 70, 70, 255); - selectButton.activeColor = RGBAToImVec4(200, 50, 50, 255); - selectButton.icon = ICON_CI_CLOSE; - selectButton.onClick = 
[this, currentVariant]() { - Model::ModelManager::getInstance().cancelDownload(m_index, currentVariant); - }; - - ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 12); - float fraction = static_cast(progress) / 100.0f; - ProgressBar::render(fraction, ImVec2(ModelManagerConstants::cardWidth - 18, 6)); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); - } - else { - selectButton.label = "Download"; - selectButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); - selectButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); - selectButton.activeColor = RGBAToImVec4(26, 95, 180, 255); - selectButton.icon = ICON_CI_CLOUD_DOWNLOAD; - selectButton.borderSize = 1.0f; - selectButton.onClick = [this, currentVariant]() { - Model::ModelManager::getInstance().setPreferredVariant(m_model.name, currentVariant); - Model::ModelManager::getInstance().downloadModel(m_index, currentVariant); - }; - } - } - else { - bool isLoadingSelected = isSelected && Model::ModelManager::getInstance().isLoadInProgress(); - bool isUnloading = isSelected && Model::ModelManager::getInstance().isUnloadInProgress(); - - // Configure button label and base state - if (isLoadingSelected || isUnloading) { - selectButton.label = isLoadingSelected ? "Loading Model..." : "Unloading Model..."; - selectButton.state = ButtonState::DISABLED; - selectButton.icon = ""; // Clear any existing icon - selectButton.borderSize = 0.0f; // Remove border - } - else { - selectButton.label = isSelected ? 
"Selected" : "Select"; - } - - // Base styling (applies to all states) - selectButton.backgroundColor = RGBAToImVec4(34, 34, 34, 255); - - // Disabled state for non-selected loading - if (!isSelected && Model::ModelManager::getInstance().isLoadInProgress()) { - selectButton.state = ButtonState::DISABLED; - } - - // Common properties - selectButton.onClick = [this]() { - std::string variant = Model::ModelManager::getInstance().getCurrentVariantForModel(m_model.name); - Model::ModelManager::getInstance().switchModel(m_model.name, variant); - }; - selectButton.size = ImVec2(ModelManagerConstants::cardWidth - 18 - 5 - 24, 0); - - // Selected state styling (only if not loading) - if (isSelected && !isLoadingSelected) { - selectButton.icon = ICON_CI_DEBUG_DISCONNECT; - selectButton.borderColor = RGBAToImVec4(172, 131, 255, 255 / 4); - selectButton.borderSize = 1.0f; - selectButton.state = ButtonState::NORMAL; - selectButton.tooltip = "Click to unload model from memory"; - selectButton.onClick = [this]() { - Model::ModelManager::getInstance().unloadModel(); - }; - } - - // Add progress bar if in loading-selected state - if (isLoadingSelected || isUnloading) { - ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 12); - ProgressBar::render(0, ImVec2(ModelManagerConstants::cardWidth - 18, 6)); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); - } - } - - Button::render(selectButton); - - if (isDownloaded) { - ImGui::SameLine(); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 2); - ImGui::SetCursorPosX(ImGui::GetCursorPosX() + ImGui::GetContentRegionAvail().x - 24 - 2); - - if (isSelected && Model::ModelManager::getInstance().isLoadInProgress()) - deleteButton.state = ButtonState::DISABLED; - else - deleteButton.state = ButtonState::NORMAL; - - Button::render(deleteButton); - } - - ImGui::EndChild(); - if (ImGui::IsItemHovered() || isSelected) { - ImVec2 min = ImGui::GetItemRectMin(); - ImVec2 max = ImGui::GetItemRectMax(); - ImU32 borderColor = IM_COL32(172, 131, 255, 255 / 2); 
- ImGui::GetWindowDrawList()->AddRect(min, max, borderColor, 8.0f, 0, 1.0f); - } - - ImGui::PopStyleVar(); - ImGui::PopStyleColor(); - ImGui::EndGroup(); - } - -private: - int m_index; - std::string m_id; - const Model::ModelData& m_model; - std::function m_onDeleteRequested; - - void renderHeader() { - Label::render(authorLabel); - Label::render(nameLabel); - } - - void renderVariantOptions(const std::string& currentVariant) { - LabelConfig variantLabel; - variantLabel.id = "##variantLabel" + std::to_string(m_index); - variantLabel.label = "Model Variants"; - variantLabel.size = ImVec2(0, 0); - variantLabel.fontType = FontsManager::REGULAR; - variantLabel.fontSize = FontsManager::SM; - variantLabel.alignment = Alignment::LEFT; - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 2); - Label::render(variantLabel); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); - - // Calculate the height for the scrollable area - // Card height minus header space minus button space at bottom - const float variantAreaHeight = 100.0f; // Adjust this value based on your layout needs - - // Create a scrollable child window for variants - ImGui::BeginChild(("##VariantScroll" + std::to_string(m_index)).c_str(), - ImVec2(ModelManagerConstants::cardWidth - 18, variantAreaHeight), - false); - - // Helper function to render a single variant option - auto renderVariant = [this, &currentVariant](const std::string& variant) { - ButtonConfig btnConfig; - btnConfig.id = "##" + variant + std::to_string(m_index); - btnConfig.icon = (currentVariant == variant) ? ICON_CI_CHECK : ICON_CI_CLOSE; - btnConfig.textColor = (currentVariant != variant) ? 
RGBAToImVec4(34, 34, 34, 255) : ImVec4(1, 1, 1, 1); - btnConfig.fontSize = FontsManager::SM; - btnConfig.size = ImVec2(24, 0); - btnConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 255); - btnConfig.onClick = [variant, this]() { - Model::ModelManager::getInstance().setPreferredVariant(m_model.name, variant); - }; - ImGui::SetCursorPosX(ImGui::GetCursorPosX() + 4); - Button::render(btnConfig); - - ImGui::SameLine(0.0f, 4.0f); - LabelConfig variantLabel; - variantLabel.id = "##" + variant + "Label" + std::to_string(m_index); - variantLabel.label = variant; - variantLabel.size = ImVec2(0, 0); - variantLabel.fontType = FontsManager::REGULAR; - variantLabel.fontSize = FontsManager::SM; - variantLabel.alignment = Alignment::LEFT; - ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 6); - Label::render(variantLabel); - }; - - // Iterate through all variants in the model - for (const auto& [variant, variantData] : m_model.variants) { - // For each variant, render a button - renderVariant(variant); - ImGui::Spacing(); - } - - // End the scrollable area - ImGui::EndChild(); - } - - ButtonConfig deleteButton; - ButtonConfig selectButton; - LabelConfig nameLabel; - LabelConfig authorLabel; -}; - -struct SortableModel { - int index; - std::string name; - - bool operator<(const SortableModel& other) const { - return name < other.name; - } -}; - -// TODO: Fix the nested modal -// when i tried to make the delete modal rendered on top of the model modal, it simply -// either didn't show up at all, or the model modal closed, and the entire application -// freezed. I tried to fix it, but I couldn't find a solution. I'm leaving it as is for now. -// Time wasted: 18 hours. 
-class ModelManagerModal { -public: - ModelManagerModal() : m_searchText(""), m_shouldFocusSearch(false) {} - - void render(bool& showDialog) { - auto& manager = Model::ModelManager::getInstance(); - - // Update sorted models when: - // - The modal is opened for the first time - // - A model is downloaded, deleted, or its status changed - bool needsUpdate = false; - - if (showDialog && !m_wasShowing) { - // Modal just opened - refresh the model list - needsUpdate = true; - // Focus the search field when the modal is opened - m_shouldFocusSearch = true; - } - - // Check for changes in download status - const auto& models = manager.getModels(); - if (models.size() != m_lastModelCount) { - // The model count changed - needsUpdate = true; - } - - // Check for changes in downloaded status - if (!needsUpdate) { - std::unordered_set currentDownloaded; - - for (size_t i = 0; i < models.size(); ++i) { - // Check if ANY variant is downloaded instead of just the current one - if (manager.isAnyVariantDownloaded(static_cast(i))) { - currentDownloaded.insert(models[i].name); // Don't need to add variant to the key - } - } - - if (currentDownloaded != m_lastDownloadedStatus) { - needsUpdate = true; - m_lastDownloadedStatus = std::move(currentDownloaded); - } - } - - if (needsUpdate) { - updateSortedModels(); - m_lastModelCount = models.size(); - filterModels(); // Apply the current search filter to the updated models - } - - m_wasShowing = showDialog; - - ImVec2 windowSize = ImGui::GetWindowSize(); - if (windowSize.x == 0) windowSize = ImGui::GetMainViewport()->Size; - const float targetWidth = windowSize.x; - float availableWidth = targetWidth - (2 * ModelManagerConstants::padding); - - int numCards = static_cast(availableWidth / (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)); - float modalWidth = (numCards * (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)) + (2 * ModelManagerConstants::padding); - if (targetWidth - modalWidth > 
(ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing) * 0.5f) { - ++numCards; - modalWidth = (numCards * (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)) + (2 * ModelManagerConstants::padding); - } - ImVec2 modalSize = ImVec2(modalWidth, windowSize.y * ModelManagerConstants::modalVerticalScale); - - auto renderCards = [numCards, this]() { - auto& manager = Model::ModelManager::getInstance(); - const auto& models = manager.getModels(); - - // Render search field at the top - renderSearchField(); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + ModelManagerConstants::sectionSpacing); - - LabelConfig downloadedSectionLabel; - downloadedSectionLabel.id = "##downloadedModelsHeader"; - downloadedSectionLabel.label = "Downloaded Models"; - downloadedSectionLabel.size = ImVec2(0, 0); - downloadedSectionLabel.fontSize = FontsManager::LG; - downloadedSectionLabel.alignment = Alignment::LEFT; - - ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, ImGui::GetCursorPosY())); - Label::render(downloadedSectionLabel); - - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); - - // Count downloaded models and check if we have any - bool hasDownloadedModels = false; - int downloadedCardCount = 0; - - // First pass to check if we have any downloaded models - for (const auto& sortableModel : m_filteredModels) { - // Check if ANY variant is downloaded instead of just current variant - if (manager.isAnyVariantDownloaded(sortableModel.index)) { - hasDownloadedModels = true; - break; - } - } - - // Render downloaded models - if (hasDownloadedModels) { - for (const auto& sortableModel : m_filteredModels) { - // Check if ANY variant is downloaded instead of just current variant - if (manager.isAnyVariantDownloaded(sortableModel.index)) { - if (downloadedCardCount % numCards == 0) { - ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, - ImGui::GetCursorPosY() + (downloadedCardCount > 0 ? 
ModelManagerConstants::cardSpacing : 0))); - } - - ModelCardRenderer card(sortableModel.index, models[sortableModel.index], - [this](int index, const std::string& variant) { - m_deleteModal.setModel(index, variant); - m_deleteModalOpen = true; - }, "downloaded"); - card.render(); - - if ((downloadedCardCount + 1) % numCards != 0) { - ImGui::SameLine(0.0f, ModelManagerConstants::cardSpacing); - } - - downloadedCardCount++; - } - } - - // Add spacing before the next section - if (downloadedCardCount % numCards != 0) { - ImGui::NewLine(); - } - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + ModelManagerConstants::sectionSpacing); - } - else { - // Show a message if no downloaded models - LabelConfig noModelsLabel; - noModelsLabel.id = "##noDownloadedModels"; - noModelsLabel.label = m_searchText.empty() ? - "No downloaded models yet. Download models from the section below." : - "No downloaded models match your search. Try a different search term."; - noModelsLabel.size = ImVec2(0, 0); - noModelsLabel.fontType = FontsManager::ITALIC; - noModelsLabel.fontSize = FontsManager::MD; - noModelsLabel.alignment = Alignment::LEFT; - - ImGui::SetCursorPosX(ModelManagerConstants::padding); - Label::render(noModelsLabel); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + ModelManagerConstants::sectionSpacing); - } - - // Separator between sections - ImGui::SetCursorPosX(ModelManagerConstants::padding); - ImGui::PushStyleColor(ImGuiCol_Separator, ImVec4(0.3f, 0.3f, 0.3f, 0.5f)); - ImGui::Separator(); - ImGui::PopStyleColor(); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); - - // Render "Available Models" section header - LabelConfig availableSectionLabel; - availableSectionLabel.id = "##availableModelsHeader"; - availableSectionLabel.label = "Available Models"; - availableSectionLabel.size = ImVec2(0, 0); - availableSectionLabel.fontSize = FontsManager::LG; - availableSectionLabel.alignment = Alignment::LEFT; - - ImGui::SetCursorPosX(ModelManagerConstants::padding); - 
Label::render(availableSectionLabel); - ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); - - // Check if we have any available models that match the search - if (m_filteredModels.empty() && !m_searchText.empty()) { - LabelConfig noModelsLabel; - noModelsLabel.id = "##noAvailableModels"; - noModelsLabel.label = "No models match your search. Try a different search term."; - noModelsLabel.size = ImVec2(0, 0); - noModelsLabel.fontType = FontsManager::ITALIC; - noModelsLabel.fontSize = FontsManager::MD; - noModelsLabel.alignment = Alignment::LEFT; - - ImGui::SetCursorPosX(ModelManagerConstants::padding); - Label::render(noModelsLabel); - } - else { - // Render all models (available for download) - for (size_t i = 0; i < m_filteredModels.size(); ++i) { - if (i % numCards == 0) { - ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, - ImGui::GetCursorPosY() + (i > 0 ? ModelManagerConstants::cardSpacing : 0))); - } - - ModelCardRenderer card(m_filteredModels[i].index, models[m_filteredModels[i].index], - [this](int index, const std::string& variant) { - m_deleteModal.setModel(index, variant); - m_deleteModalOpen = true; - }); - card.render(); - - if ((i + 1) % numCards != 0 && i < m_filteredModels.size() - 1) { - ImGui::SameLine(0.0f, ModelManagerConstants::cardSpacing); - } - } - } - }; - - ModalConfig config{ - "Model Manager", - "Model Manager", - modalSize, - renderCards, - showDialog - }; - config.padding = ImVec2(ModelManagerConstants::padding, 8.0f); - ModalWindow::render(config); - - // Render the delete modal if it's open. 
- if (m_deleteModalOpen) { - m_deleteModal.render(m_deleteModalOpen); - - // Mark for update on next frame after deletion - if (!m_deleteModalOpen && m_wasDeleteModalOpen) { - m_needsUpdateAfterDelete = true; - } - } - - if (m_wasDeleteModalOpen && !m_deleteModalOpen) { - showDialog = true; - ImGui::OpenPopup(config.id.c_str()); - } - - if (m_needsUpdateAfterDelete && !m_deleteModalOpen) { - updateSortedModels(); - filterModels(); // Apply search filter after updating models - m_needsUpdateAfterDelete = false; - } - - m_wasDeleteModalOpen = m_deleteModalOpen; - - if (!ImGui::IsPopupOpen(config.id.c_str())) { - showDialog = false; - } - } - -private: - DeleteModelModalComponent m_deleteModal; - bool m_deleteModalOpen = false; - bool m_wasDeleteModalOpen = false; - bool m_wasShowing = false; - bool m_needsUpdateAfterDelete = false; - size_t m_lastModelCount = 0; - std::unordered_set m_lastDownloadedStatus; - std::vector m_sortedModels; - std::vector m_filteredModels; // New: Filtered list of models based on search - - // Search related variables - std::string m_searchText; - bool m_shouldFocusSearch; - - void updateSortedModels() { - auto& manager = Model::ModelManager::getInstance(); - const auto& models = manager.getModels(); - - // Clear and rebuild the sorted model list - m_sortedModels.clear(); - m_sortedModels.reserve(models.size()); - - for (size_t i = 0; i < models.size(); ++i) { - // Store the index and name directly, avoiding storing pointers - m_sortedModels.push_back({ static_cast(i), models[i].name }); - } - - // Sort models alphabetically by name - std::sort(m_sortedModels.begin(), m_sortedModels.end()); - - // Initialize filtered models with all models when sort is updated - filterModels(); - } - - // Filter models based on search text - void filterModels() { - m_filteredModels.clear(); - auto& manager = Model::ModelManager::getInstance(); - const auto& models = manager.getModels(); - - if (m_searchText.empty()) { - // If no search term, show all 
models - m_filteredModels = m_sortedModels; - return; - } - - // Convert search text to lowercase for case-insensitive comparison - std::string searchLower = m_searchText; - std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), - [](unsigned char c) { return std::tolower(c); }); - - // Filter models based on name OR author containing the search text - for (const auto& model : m_sortedModels) { - // Get the model data using the stored index - const auto& modelData = models[model.index]; - - // Convert name and author to lowercase for case-insensitive comparison - std::string nameLower = modelData.name; - std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), - [](unsigned char c) { return std::tolower(c); }); - - std::string authorLower = modelData.author; - std::transform(authorLower.begin(), authorLower.end(), authorLower.begin(), - [](unsigned char c) { return std::tolower(c); }); - - // Add model to filtered results if either name OR author contains the search text - if (nameLower.find(searchLower) != std::string::npos || - authorLower.find(searchLower) != std::string::npos) { - m_filteredModels.push_back(model); - } - } - } - - // New method: Render search field - void renderSearchField() { - ImGui::SetCursorPosX(ModelManagerConstants::padding); - - // Create and configure search input field - InputFieldConfig searchConfig( - "##modelSearch", - ImVec2(ImGui::GetContentRegionAvail().x, 32.0f), - m_searchText, - m_shouldFocusSearch - ); - searchConfig.placeholderText = "Search models..."; - searchConfig.processInput = [this](const std::string& text) { - // No need to handle submission specifically as we'll filter on every change - }; - - // Style the search field - searchConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 255); - searchConfig.hoverColor = RGBAToImVec4(44, 44, 44, 255); - searchConfig.activeColor = RGBAToImVec4(54, 54, 54, 255); - - // Render the search field - InputField::render(searchConfig); - - // Filter 
models whenever search text changes - static std::string lastSearch; - if (lastSearch != m_searchText) { - lastSearch = m_searchText; - filterModels(); - } - } -}; \ No newline at end of file diff --git a/include/ui/chat/preset_sidebar.hpp b/include/ui/chat/preset_sidebar.hpp index 9c8435a..2b71a4a 100644 --- a/include/ui/chat/preset_sidebar.hpp +++ b/include/ui/chat/preset_sidebar.hpp @@ -2,6 +2,7 @@ #include "imgui.h" #include "model/preset_manager.hpp" +#include "system_prompt_modal.hpp" #include "ui/widgets.hpp" #include "config.hpp" #include "nfd.h" @@ -167,10 +168,12 @@ class PresetSelectionComponent { class SamplingSettingsComponent { public: - // Takes sidebarWidth by reference. - SamplingSettingsComponent(float& sidebarWidth, bool& focusSystemPrompt) - : m_sidebarWidth(sidebarWidth), m_focusSystemPrompt(focusSystemPrompt) + // Takes sidebarWidth by reference and sharedSystemPromptBuffer by reference + SamplingSettingsComponent(float& sidebarWidth, bool& focusSystemPrompt, std::string& sharedSystemPromptBuffer) + : m_sidebarWidth(sidebarWidth), m_focusSystemPrompt(focusSystemPrompt), m_sharedSystemPromptBuffer(sharedSystemPromptBuffer) { + // Initialize the edit button handler + m_onEditSystemPromptRequested = []() {}; } void render() { @@ -178,34 +181,64 @@ class SamplingSettingsComponent { if (!currentPresetOpt) return; auto& currentPreset = currentPresetOpt->get(); - // Create a temporary buffer with sufficient capacity - static std::string tempSystemPrompt(Config::InputField::TEXT_SIZE, '\0'); - - // On first render or when preset changes, copy current value to the buffer + // Sync the shared buffer with the current preset on first load or when preset changes static int lastPresetId = -1; if (lastPresetId != currentPreset.id) { - tempSystemPrompt = currentPreset.systemPrompt; + m_sharedSystemPromptBuffer = currentPreset.systemPrompt; lastPresetId = currentPreset.id; } - // Render the system prompt label and multi-line input field + // Render the system 
prompt label and edit button ImGui::Spacing(); ImGui::Spacing(); + + // Create a row for the label and button + ImGui::BeginGroup(); Label::render(m_systemPromptLabel); + + // Calculate position for the edit button + float labelWidth = ImGui::CalcTextSize(m_systemPromptLabel.label.c_str()).x + + Config::Icon::DEFAULT_FONT_SIZE + m_systemPromptLabel.gap.value(); + + ImGui::SameLine(); + + // Position the edit button to edge of the sidebar + ImGui::SetCursorPosX(m_sidebarWidth - 38); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 4); + + // Create edit button + ButtonConfig editButtonConfig; + editButtonConfig.id = "##editsystemprompt"; + editButtonConfig.icon = ICON_CI_EDIT; + editButtonConfig.size = ImVec2(24, 24); + editButtonConfig.alignment = Alignment::CENTER; + editButtonConfig.backgroundColor = Config::Color::TRANSPARENT_COL; + editButtonConfig.hoverColor = Config::Color::SECONDARY; + editButtonConfig.activeColor = Config::Color::PRIMARY; + editButtonConfig.tooltip = "Edit System Prompt in Modal"; + editButtonConfig.onClick = [&]() { + // Call the callback when edit button is clicked + if (m_onEditSystemPromptRequested) { + m_onEditSystemPromptRequested(); + } + }; + + Button::render(editButtonConfig); + ImGui::EndGroup(); + ImGui::Spacing(); ImGui::Spacing(); + // Use the shared buffer for the input field InputFieldConfig inputConfig( "##systemprompt", ImVec2(m_sidebarWidth - 20, 100), - tempSystemPrompt, // Use the temporary buffer instead + m_sharedSystemPromptBuffer, m_focusSystemPrompt ); inputConfig.placeholderText = "Enter your system prompt here..."; - inputConfig.processInput = [&currentPreset](const std::string& input) { - // Copy the input to our temporary buffer first - tempSystemPrompt = input; - - // Then safely update the preset's system prompt + inputConfig.processInput = [&currentPreset, this](const std::string& input) { + // Update shared buffer and preset + m_sharedSystemPromptBuffer = input; currentPreset.systemPrompt = input; }; @@ -232,9 +265,14 @@ 
class SamplingSettingsComponent { void setSystemPromptLabel(const LabelConfig& label) { m_systemPromptLabel = label; } void setModelSettingsLabel(const LabelConfig& label) { m_modelSettingsLabel = label; } + // Callback used to request opening the system prompt modal + std::function m_onEditSystemPromptRequested; + private: float& m_sidebarWidth; bool& m_focusSystemPrompt; + std::string& m_sharedSystemPromptBuffer; + LabelConfig m_systemPromptLabel{ "##systempromptlabel", "System Prompt", @@ -368,17 +406,27 @@ class ModelPresetSidebar { ModelPresetSidebar() : m_sidebarWidth(Config::ChatHistorySidebar::SIDEBAR_WIDTH), m_presetSelectionComponent(m_sidebarWidth), - m_samplingSettingsComponent(m_sidebarWidth, m_focusSystemPrompt), + m_samplingSettingsComponent(m_sidebarWidth, m_focusSystemPrompt, m_sharedSystemPromptBuffer), m_saveAsDialogComponent(m_sidebarWidth, m_focusNewPresetName), + m_systemPromptModalComponent(m_sidebarWidth), m_exportButtonComponent(m_sidebarWidth) { + // Initialize the system prompt buffer with sufficient capacity + m_sharedSystemPromptBuffer.reserve(Config::InputField::TEXT_SIZE); + // Set up the callback for "Save as New" so that it shows the modal. 
m_presetSelectionComponent.m_onSaveAsRequested = [this]() { m_showSaveAsDialog = true; }; + + // Set up the callback for "Edit System Prompt" button + m_samplingSettingsComponent.m_onEditSystemPromptRequested = [this]() { + m_showSystemPromptModal = true; + m_focusModalEditor = true; // Focus editor when opening the modal + }; } void render() { ImGuiIO& io = ImGui::GetIO(); - const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT; + const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT - Config::FOOTER_HEIGHT; ImGui::SetNextWindowPos(ImVec2(io.DisplaySize.x - m_sidebarWidth, Config::TITLE_BAR_HEIGHT), ImGuiCond_Always); ImGui::SetNextWindowSize(ImVec2(m_sidebarWidth, sidebarHeight), ImGuiCond_Always); @@ -405,21 +453,26 @@ class ModelPresetSidebar { ImGui::End(); m_saveAsDialogComponent.render(m_showSaveAsDialog, m_newPresetName); + m_systemPromptModalComponent.render(m_showSystemPromptModal, m_sharedSystemPromptBuffer, m_focusModalEditor); } - float getSidebarWidth() const { - return m_sidebarWidth; - } + float getSidebarWidth() const { + return m_sidebarWidth; + } private: float m_sidebarWidth; bool m_showSaveAsDialog = false; + bool m_showSystemPromptModal = false; std::string m_newPresetName; + std::string m_sharedSystemPromptBuffer; bool m_focusSystemPrompt = true; + bool m_focusModalEditor = true; bool m_focusNewPresetName = true; PresetSelectionComponent m_presetSelectionComponent; SamplingSettingsComponent m_samplingSettingsComponent; SaveAsDialogComponent m_saveAsDialogComponent; + SystemPromptModalComponent m_systemPromptModalComponent; ExportButtonComponent m_exportButtonComponent; }; \ No newline at end of file diff --git a/include/ui/chat/system_prompt_modal.hpp b/include/ui/chat/system_prompt_modal.hpp new file mode 100644 index 0000000..1b4de04 --- /dev/null +++ b/include/ui/chat/system_prompt_modal.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "imgui.h" +#include "model/preset_manager.hpp" +#include 
"ui/widgets.hpp" +#include "config.hpp" +#include + +class SystemPromptModalComponent { +public: + SystemPromptModalComponent(float& sidebarWidth) + : m_sidebarWidth(sidebarWidth) + { + } + + void render(bool& showDialog, std::string& sharedSystemPromptBuffer, bool& focusEditor) { + // Always render the modal so that it stays open if already open. + ModalConfig config{ + "Edit System Prompt", // Title + "System Prompt Editor", // Identifier + ImVec2(600, 400), // Larger size for text editing + [&]() { + // Get the current preset + auto currentPresetOpt = Model::PresetManager::getInstance().getCurrentPreset(); + if (!currentPresetOpt) return; + + // Create a multiline text input for the system prompt using the shared buffer + InputFieldConfig inputConfig( + "##systempromptmodal", + ImVec2(ImGui::GetWindowSize().x - 32.0f, ImGui::GetWindowSize().y - 64), + sharedSystemPromptBuffer, + focusEditor + ); + inputConfig.placeholderText = "Enter your system prompt here..."; + inputConfig.flags = ImGuiInputTextFlags_AllowTabInput; + inputConfig.processInput = [&](const std::string& input) { + // Update shared buffer and preset directly + sharedSystemPromptBuffer = input; + if (currentPresetOpt) { + currentPresetOpt->get().systemPrompt = input; + } + }; + + // Render the multiline input + InputField::renderMultiline(inputConfig); + }, + showDialog // Initial open flag passed in. + }; + + // Set modal padding + config.padding = ImVec2(16.0f, 8.0f); + ModalWindow::render(config); + + // If the popup is no longer open, ensure showDialog remains false. 
+ if (!ImGui::IsPopupOpen(config.id.c_str())) + showDialog = false; + } + +private: + float& m_sidebarWidth; +}; \ No newline at end of file diff --git a/include/ui/markdown.hpp b/include/ui/markdown.hpp index 12fafc5..47db23e 100644 --- a/include/ui/markdown.hpp +++ b/include/ui/markdown.hpp @@ -4,6 +4,15 @@ #include #include +#ifdef _WIN32 +#include +#include +#elif defined(__APPLE__) +#include +#else // Linux and other Unix-like systems +#include +#endif + #include "ui/widgets.hpp" #include "config.hpp" @@ -84,6 +93,26 @@ class MarkdownRenderer : public imgui_md return true; } + void open_url() const override + { + // Get the URL from the base class's m_href member + const std::string url = m_href; + + if (url.empty()) { + return; // No URL to open + } + +#ifdef _WIN32 + ShellExecuteA(NULL, "open", url.c_str(), NULL, NULL, SW_SHOWNORMAL); +#elif defined(__APPLE__) + std::string cmd = "open \"" + url + "\""; + system(cmd.c_str()); +#else + std::string cmd = "xdg-open \"" + url + "\""; + system(cmd.c_str()); +#endif + } + void html_div(const std::string& dclass, bool enter) override { // Example toggling text color if
... diff --git a/include/ui/model_manager_modal.hpp b/include/ui/model_manager_modal.hpp new file mode 100644 index 0000000..819cd28 --- /dev/null +++ b/include/ui/model_manager_modal.hpp @@ -0,0 +1,1658 @@ +#pragma once + +#include "imgui.h" +#include "ui/widgets.hpp" +#include "ui/markdown.hpp" +#include "model/model_manager.hpp" +#include "model/gguf_reader.hpp" +#include "ui/fonts.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ModelManagerConstants { + constexpr float cardWidth = 200.0f; + constexpr float cardHeight = 220.0f; + constexpr float cardSpacing = 10.0f; + constexpr float padding = 16.0f; + constexpr float modalVerticalScale = 0.9f; + constexpr float sectionSpacing = 20.0f; + constexpr float sectionHeaderHeight = 30.0f; +} + +class AddCustomModelModalComponent { +public: + AddCustomModelModalComponent() { + // Add variant button + ButtonConfig addVariantButton; + addVariantButton.id = "##confirmAddVariant"; + addVariantButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); + addVariantButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); + addVariantButton.activeColor = RGBAToImVec4(26, 95, 180, 255); + addVariantButton.size = ImVec2(130, 0); + addVariantButton.onClick = [this]() { + if (validateVariantForm()) { + Model::ModelVariant variant; + variant.type = m_currentVariantName; + + // Determine if input is URL or local path + bool isUrl = isUrlInput(m_currentVariantPath); + + if (isUrl) { + // If it's a URL, set downloadLink and generate a local path + variant.downloadLink = m_currentVariantPath; + + // Extract filename from URL + std::string filename = getFilenameFromPath(m_currentVariantPath); + + // Generate path in format: models/// + variant.path = "models/" + m_modelName + "/" + m_currentVariantName + "/" + filename; + + // For URL, mark as not downloaded yet + variant.isDownloaded = false; + variant.downloadProgress = 0.0; + } + else { + // If it's a local path, set path and leave 
downloadLink empty + variant.path = m_currentVariantPath; + variant.downloadLink = ""; // Empty for local files + variant.isDownloaded = true; // Already available locally + variant.downloadProgress = 100.0; + } + + variant.lastSelected = 0; + variant.size = getFileSize(m_currentVariantPath, isUrl); + + // Check if we're editing or adding a new variant + if (!m_editingVariantName.empty()) { + // If the name changed, remove the old entry + if (m_editingVariantName != m_currentVariantName) { + m_variants.erase(m_editingVariantName); + } + // Add with new or same name + m_variants[m_currentVariantName] = variant; + m_editingVariantName.clear(); // Clear edit mode + } + else { + // Add new variant + m_variants[m_currentVariantName] = variant; + } + + // Clear the form and collapse it + m_currentVariantName.clear(); + m_currentVariantPath.clear(); + m_variantErrorMessage.clear(); + m_showVariantForm = false; + + // Reset focus for next time the variant form opens + s_focusVariantName = true; + s_focusVariantPath = false; + } + }; + + variantButtons.push_back(addVariantButton); + } + + void render(bool& openModal) { + // Reset model added flag when modal opens + if (openModal && !m_wasOpen) { + m_modelAdded = false; + s_focusAuthor = true; + s_focusModelName = false; + s_focusVariantName = false; + s_focusVariantPath = false; + } + + m_wasOpen = openModal; + + ModalConfig config{ + "Add Custom Model", + "Add Custom Model", + ImVec2(500, 550), + [this]() { + ImGui::PushStyleColor(ImGuiCol_ScrollbarBg, ImVec4(0, 0, 0, 0)); + ImGui::BeginChild("##addCustomModelChild", ImVec2(0, + ImGui::GetContentRegionAvail().y - 42), false); + renderMainForm(); + ImGui::EndChild(); + ImGui::PopStyleColor(); + + ButtonConfig submitButton; + submitButton.id = "##submitAddCustomModel"; + submitButton.label = "Submit"; + submitButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); + submitButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); + submitButton.activeColor = RGBAToImVec4(26, 95, 
180, 255); + submitButton.size = ImVec2(ImGui::GetContentRegionAvail().x - 12.0F, 0); + submitButton.onClick = [this]() { + if (validateMainForm()) { + if (submitCustomModel()) { + m_modelAdded = true; + ImGui::CloseCurrentPopup(); + } + else { + m_errorMessage = "Failed to add custom model. Check the model file and try again."; + } + } + }; + + if (m_variants.empty()) { + submitButton.state = ButtonState::DISABLED; + } + + ImGui::SetCursorPos(ImVec2( + ImGui::GetCursorPosX() + 6.0F, + ImGui::GetCursorPosY() + ImGui::GetContentRegionAvail().y - 30.0F + )); + Button::render(submitButton); + }, + openModal + }; + config.padding = ImVec2(16.0f, 16.0f); + ModalWindow::render(config); + + if (!ImGui::IsPopupOpen(config.id.c_str())) { + openModal = false; + if (!m_modelAdded) { + clearForm(); + } + } + } + + // Check if a model was successfully added in the last session + bool wasModelAdded() const { + return m_modelAdded; + } + + // Reset the model added flag after handling it + void resetModelAddedFlag() { + m_modelAdded = false; + } + +private: + // Main form data + std::string m_authorName; + std::string m_modelName; + std::map m_variants; + std::string m_errorMessage; + bool m_wasOpen = false; + bool m_modelAdded = false; + + // Variant form data + bool m_showVariantForm = false; + std::string m_currentVariantName; + std::string m_currentVariantPath; + std::string m_variantErrorMessage; + std::string m_editingVariantName; + + // Static focus control variables + static bool s_focusAuthor; + static bool s_focusModelName; + static bool s_focusVariantName; + static bool s_focusVariantPath; + + // Static counter for unique IDs + static int s_idCounter; + + // Buttons + std::vector variantButtons; + + // GGUF reader + GGUFMetadataReader m_ggufReader; + + // Check if input is a URL + bool isUrlInput(const std::string& input) { + // Simple regex to detect URLs + static const std::regex urlPattern( + R"(^(https?|ftp)://)" // protocol + R"([^\s/$.?#].[^\s]*$)", // domain and 
path + std::regex::icase + ); + + return std::regex_match(input, urlPattern); + } + + // Extract filename from path or URL + std::string getFilenameFromPath(const std::string& path) { + // First try to use filesystem for local paths + std::string filename; + + try { + // For URLs, extract the last part of the path + if (isUrlInput(path)) { + // Find the last '/' character + size_t lastSlash = path.find_last_of('/'); + if (lastSlash != std::string::npos && lastSlash < path.length() - 1) { + filename = path.substr(lastSlash + 1); + + // Handle URL query parameters + size_t queryPos = filename.find('?'); + if (queryPos != std::string::npos) { + filename = filename.substr(0, queryPos); + } + } + else { + // Fallback + filename = "model.gguf"; + } + } + else { + // For local paths, use std::filesystem + std::filesystem::path fsPath(path); + filename = fsPath.filename().string(); + } + } + catch (...) { + // Last resort fallback + filename = "model.gguf"; + } + + // Ensure the filename has .gguf extension + if (filename.length() < 5 || + filename.substr(filename.length() - 5) != ".gguf") { + filename += ".gguf"; + } + + return filename; + } + + // CURL callback for URL file size check + static size_t headerCallback(char* buffer, size_t size, size_t nitems, void* userdata) { + size_t totalSize = size * nitems; + std::string header(buffer, totalSize); + + // Convert header to lowercase for case-insensitive comparison + std::transform(header.begin(), header.end(), header.begin(), + [](unsigned char c) { return std::tolower(c); }); + + // Check if this is the content-length header + if (header.find("content-length:") == 0) { + // Extract the size value + std::string lengthStr = header.substr(15); // Skip "content-length:" + // Trim whitespace + lengthStr.erase(0, lengthStr.find_first_not_of(" \t\r\n")); + lengthStr.erase(lengthStr.find_last_not_of(" \t\r\n") + 1); + + // Store the size in the userdata pointer + if (!lengthStr.empty()) { + try { + *(size_t*)userdata = 
std::stoull(lengthStr); + } + catch (...) { + // Conversion error, ignore + } + } + } + + return totalSize; + } + + // Get file size in GB from URL using a HEAD request + float getUrlFileSize(const std::string& url) { + size_t fileSizeBytes = 0; + + CURL* curl = curl_easy_init(); + if (curl) { + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // HEAD request + curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, headerCallback); + curl_easy_setopt(curl, CURLOPT_HEADERDATA, &fileSizeBytes); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); // Follow redirects + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 10L); // 10 second timeout + + CURLcode res = curl_easy_perform(curl); + curl_easy_cleanup(curl); + + if (res != CURLE_OK) { + // Failed to get size, return 0 + return 0.0f; + } + } + + // Convert bytes to GB + float fileSizeGB = static_cast(fileSizeBytes) / (1024.0f * 1024.0f * 1024.0f); + return fileSizeGB; + } + + // Get file size in GB from local path + float getLocalFileSize(const std::string& path) { + try { + std::filesystem::path fsPath(path); + if (std::filesystem::exists(fsPath) && std::filesystem::is_regular_file(fsPath)) { + // Get size in bytes and convert to GB + uintmax_t sizeInBytes = std::filesystem::file_size(fsPath); + float sizeInGB = static_cast(sizeInBytes) / (1024.0f * 1024.0f * 1024.0f); + return sizeInGB; + } + } + catch (...) 
{ + // If there's any error, return 0 + } + return 0.0f; + } + + // Get file size in GB from either URL or local path + float getFileSize(const std::string& path, bool isUrl) { + if (isUrl) { + return getUrlFileSize(path); + } + else { + return getLocalFileSize(path); + } + } + + // Generate a unique ID for UI elements + std::string generateUniqueId(const std::string& prefix) { + return prefix + std::to_string(s_idCounter++); + } + + // Start editing a variant + void startEditingVariant(const std::string& variantName) { + m_editingVariantName = variantName; + m_currentVariantName = variantName; + + const auto& variant = m_variants[variantName]; + // Use downloadLink if available, otherwise use path + m_currentVariantPath = !variant.downloadLink.empty() + ? variant.downloadLink + : variant.path; + + m_showVariantForm = true; + s_focusVariantName = true; + } + + void renderMainForm() { + // Display error message if any + if (!m_errorMessage.empty()) { + LabelConfig errorLabel; + errorLabel.id = "##mainErrorMessage"; + errorLabel.label = m_errorMessage; + errorLabel.size = ImVec2(0, 0); + errorLabel.fontType = FontsManager::ITALIC; + errorLabel.fontSize = FontsManager::SM; + errorLabel.color = ImVec4(1.0f, 0.3f, 0.3f, 1.0f); + errorLabel.alignment = Alignment::LEFT; + Label::render(errorLabel); + ImGui::Spacing(); + } + + // Author input + LabelConfig authorLabel; + authorLabel.id = "##modelAuthorLabel"; + authorLabel.label = "Author"; + authorLabel.size = ImVec2(0, 0); + authorLabel.fontType = FontsManager::REGULAR; + authorLabel.fontSize = FontsManager::MD; + authorLabel.alignment = Alignment::LEFT; + Label::render(authorLabel); + + InputFieldConfig authorFieldConfig( + "##modelAuthorInput", + ImVec2(ImGui::GetContentRegionAvail().x - 12.0F, 32.0f), + m_authorName, + s_focusAuthor + ); + // Reset focus flag after use + s_focusAuthor = false; + + authorFieldConfig.placeholderText = "Enter author name"; + authorFieldConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 
255); + authorFieldConfig.hoverColor = RGBAToImVec4(44, 44, 44, 255); + authorFieldConfig.activeColor = RGBAToImVec4(54, 54, 54, 255); + InputField::render(authorFieldConfig); + ImGui::Spacing(); + ImGui::Spacing(); + + // Model name input + LabelConfig nameLabel; + nameLabel.id = "##modelNameLabel"; + nameLabel.label = "Model Name"; + nameLabel.size = ImVec2(0, 0); + nameLabel.fontType = FontsManager::REGULAR; + nameLabel.fontSize = FontsManager::MD; + nameLabel.alignment = Alignment::LEFT; + Label::render(nameLabel); + + InputFieldConfig nameFieldConfig( + "##modelNameInput", + ImVec2(ImGui::GetContentRegionAvail().x - 12.0F, 32.0f), + m_modelName, + s_focusModelName + ); + // Reset focus flag after use + s_focusModelName = false; + + nameFieldConfig.placeholderText = "Enter model name"; + nameFieldConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + nameFieldConfig.hoverColor = RGBAToImVec4(44, 44, 44, 255); + nameFieldConfig.activeColor = RGBAToImVec4(54, 54, 54, 255); + InputField::render(nameFieldConfig); + ImGui::Spacing(); + ImGui::Spacing(); + + // Variants section + LabelConfig variantsLabel; + variantsLabel.id = "##modelVariantsLabel"; + variantsLabel.label = "Variants:"; + variantsLabel.size = ImVec2(0, 0); + variantsLabel.fontType = FontsManager::REGULAR; + variantsLabel.fontSize = FontsManager::MD; + variantsLabel.alignment = Alignment::LEFT; + Label::render(variantsLabel); + ImGui::Spacing(); + + // Display existing variants in a scrollable area + if (!m_variants.empty()) { + ImGui::PushStyleColor(ImGuiCol_ChildBg, RGBAToImVec4(26, 26, 26, 255)); + ImGui::PushStyleVar(ImGuiStyleVar_ChildRounding, 8.0f); + ImGui::BeginChild("##variantsList", ImVec2(ImGui::GetContentRegionAvail().x, 180), true); + + int variantIdx = 0; + for (auto& [variantName, variant] : m_variants) { + std::string variantId = "variant_" + std::to_string(variantIdx++); + ImGui::PushID(variantId.c_str()); + + // Variant section with border + ImGui::BeginGroup(); + 
ImGui::PushStyleColor(ImGuiCol_ChildBg, RGBAToImVec4(34, 34, 34, 255)); + ImGui::PushStyleVar(ImGuiStyleVar_ChildRounding, 4.0f); + ImGui::BeginChild(("##variantItem_" + variantId).c_str(), ImVec2(ImGui::GetContentRegionAvail().x, 100), true); + + // Variant name - bold + LabelConfig variantNameLabel; + variantNameLabel.id = "##variant_name_" + variantId; + variantNameLabel.label = "Variant: " + variantName; + variantNameLabel.fontType = FontsManager::BOLD; + variantNameLabel.fontSize = FontsManager::MD; + Label::render(variantNameLabel); + + // Show path or URL info + std::string locationInfo; + if (!variant.downloadLink.empty()) { + locationInfo = "URL: " + variant.downloadLink; + + // Also show the local path where it will be downloaded + LabelConfig variantPathLabel; + variantPathLabel.id = "##variant_download_path_" + variantId; + variantPathLabel.label = "Download path: " + variant.path; + variantPathLabel.fontType = FontsManager::ITALIC; + variantPathLabel.fontSize = FontsManager::SM; + Label::render(variantPathLabel); + } + else { + locationInfo = "Path: " + variant.path; + } + + // Display the path/URL info + LabelConfig variantPathLabel; + variantPathLabel.id = "##variant_path_" + variantId; + variantPathLabel.label = locationInfo; + variantPathLabel.fontType = FontsManager::REGULAR; + variantPathLabel.fontSize = FontsManager::SM; + Label::render(variantPathLabel); + + // Edit button - at right side + ImGui::SetCursorPos(ImVec2(ImGui::GetContentRegionAvail().x - 48, 10)); + ButtonConfig editVariantBtn; + editVariantBtn.id = "##editVariant_" + variantId; + editVariantBtn.icon = ICON_CI_EDIT; + editVariantBtn.size = ImVec2(24, 24); + editVariantBtn.tooltip = "Edit variant"; + editVariantBtn.onClick = [this, variantName]() { + startEditingVariant(variantName); + }; + Button::render(editVariantBtn); + + // Delete button - small, at the right side + ImGui::SetCursorPos(ImVec2(ImGui::GetContentRegionAvail().x - 18, 10)); + ButtonConfig deleteVariantBtn; + 
deleteVariantBtn.id = "##deleteVariant_" + variantId; + deleteVariantBtn.icon = ICON_CI_TRASH; + deleteVariantBtn.hoverColor = RGBAToImVec4(220, 70, 70, 255); + deleteVariantBtn.size = ImVec2(24, 24); + deleteVariantBtn.tooltip = "Delete variant"; + deleteVariantBtn.onClick = [this, variantName]() { + // If we're currently editing this variant, cancel editing + if (m_editingVariantName == variantName) { + m_editingVariantName.clear(); + m_currentVariantName.clear(); + m_currentVariantPath.clear(); + m_showVariantForm = false; + } + m_variants.erase(variantName); + }; + Button::render(deleteVariantBtn); + + ImGui::EndChild(); + ImGui::PopStyleVar(); + ImGui::PopStyleColor(); + ImGui::EndGroup(); + + ImGui::PopID(); + ImGui::Spacing(); + } + + ImGui::EndChild(); + ImGui::PopStyleVar(); + ImGui::PopStyleColor(); + } + else { + LabelConfig noVariantsLabel; + noVariantsLabel.id = "##noVariants"; + noVariantsLabel.label = "No variants added. Click 'Add New Variant' button below."; + noVariantsLabel.fontType = FontsManager::ITALIC; + noVariantsLabel.fontSize = FontsManager::SM; + noVariantsLabel.color = ImVec4(0.7f, 0.7f, 0.7f, 1.0f); + Label::render(noVariantsLabel); + ImGui::Spacing(); + } + + // Collapsible "Add New Variant" section + ImGui::Spacing(); + + // Determine button label based on whether we're editing or adding + std::string buttonLabel = "Add New Variant"; + if (m_showVariantForm) { + buttonLabel = m_editingVariantName.empty() ? "Cancel Adding Variant" : "Cancel Editing Variant"; + } + + ButtonConfig toggleVariantFormButton; + toggleVariantFormButton.id = "##toggleAddNewVariant"; + toggleVariantFormButton.label = buttonLabel; + toggleVariantFormButton.icon = m_showVariantForm ? 
ICON_CI_CLOSE : ICON_CI_PLUS; + toggleVariantFormButton.alignment = Alignment::LEFT; + toggleVariantFormButton.size = ImVec2( + ImGui::CalcTextSize(buttonLabel.c_str()).x + /*padding + icon size*/ 40.0f, 32.0f); + toggleVariantFormButton.onClick = [this]() { + if (m_showVariantForm) { + // Cancel editing/adding + m_showVariantForm = false; + m_currentVariantName.clear(); + m_currentVariantPath.clear(); + m_variantErrorMessage.clear(); + m_editingVariantName.clear(); + } + else { + // Start adding new + m_showVariantForm = true; + m_editingVariantName.clear(); + m_currentVariantName.clear(); + m_currentVariantPath.clear(); + s_focusVariantName = true; + s_focusVariantPath = false; + } + }; + Button::render(toggleVariantFormButton); + + // Render the collapsible variant form if it's visible + if (m_showVariantForm) { + ImGui::Spacing(); + + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0F); + + ImGui::PushStyleColor(ImGuiCol_ChildBg, RGBAToImVec4(30, 30, 30, 255)); + ImGui::PushStyleVar(ImGuiStyleVar_ChildRounding, 5.0f); + ImGui::BeginChild("##variantFormSection", ImVec2(ImGui::GetContentRegionAvail().x, 256), true); + + // Title - changes based on whether we're editing or adding + LabelConfig variantFormLabel; + variantFormLabel.id = "##addVariantTitle"; + variantFormLabel.label = m_editingVariantName.empty() ? 
"Add New Variant" : "Edit Variant"; + variantFormLabel.fontType = FontsManager::BOLD; + variantFormLabel.fontSize = FontsManager::MD; + variantFormLabel.alignment = Alignment::LEFT; + Label::render(variantFormLabel); + ImGui::Spacing(); + + // Display error message if any + if (!m_variantErrorMessage.empty()) { + LabelConfig errorLabel; + errorLabel.id = "##variantErrorMessage"; + errorLabel.label = m_variantErrorMessage; + errorLabel.size = ImVec2(0, 0); + errorLabel.fontType = FontsManager::ITALIC; + errorLabel.fontSize = FontsManager::SM; + errorLabel.color = ImVec4(1.0f, 0.3f, 0.3f, 1.0f); + errorLabel.alignment = Alignment::LEFT; + Label::render(errorLabel); + ImGui::Spacing(); + } + + // Variant Name + LabelConfig variantNameLabel; + variantNameLabel.id = "##variantNameLabel"; + variantNameLabel.label = "Variant Name"; + variantNameLabel.size = ImVec2(0, 0); + variantNameLabel.fontType = FontsManager::REGULAR; + variantNameLabel.fontSize = FontsManager::MD; + variantNameLabel.alignment = Alignment::LEFT; + Label::render(variantNameLabel); + + InputFieldConfig variantNameField( + "##variantNameInput", + ImVec2(ImGui::GetContentRegionAvail().x, 32.0f), + m_currentVariantName, + s_focusVariantName + ); + // Reset focus flag after use + s_focusVariantName = false; + + variantNameField.placeholderText = "e.g., q4_0, f16, etc."; + variantNameField.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + variantNameField.hoverColor = RGBAToImVec4(44, 44, 44, 255); + variantNameField.activeColor = RGBAToImVec4(54, 54, 54, 255); + InputField::render(variantNameField); + ImGui::Spacing(); + + // Path / URL + LabelConfig variantPathLabel; + variantPathLabel.id = "##variantPathLabel"; + variantPathLabel.label = "Path / URL to GGUF"; + variantPathLabel.size = ImVec2(0, 0); + variantPathLabel.fontType = FontsManager::REGULAR; + variantPathLabel.fontSize = FontsManager::MD; + variantPathLabel.alignment = Alignment::LEFT; + Label::render(variantPathLabel); + + // Add info about 
URL vs path handling + LabelConfig pathInfoLabel; + pathInfoLabel.id = "##pathInfoLabel"; + pathInfoLabel.label = "Enter a URL (https://) to download or a local file path"; + pathInfoLabel.fontType = FontsManager::ITALIC; + pathInfoLabel.fontSize = FontsManager::SM; + pathInfoLabel.color = ImVec4(0.7f, 0.7f, 0.7f, 1.0f); + Label::render(pathInfoLabel); + + InputFieldConfig variantPathField( + "##variantPathInput", + ImVec2(ImGui::GetContentRegionAvail().x - 48, 32.0f), + m_currentVariantPath, + s_focusVariantPath + ); + // Reset focus flag after use + s_focusVariantPath = false; + + variantPathField.placeholderText = "Enter path or URL to the model file"; + variantPathField.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + variantPathField.hoverColor = RGBAToImVec4(44, 44, 44, 255); + variantPathField.activeColor = RGBAToImVec4(54, 54, 54, 255); + InputField::render(variantPathField); + + ImGui::SameLine(); + ButtonConfig browseButton; + browseButton.id = "##browseVariantPath"; + browseButton.icon = ICON_CI_FOLDER; + browseButton.size = ImVec2(38, 38); + browseButton.onClick = [this]() { + openFileDialog(); + }; + Button::render(browseButton); + + ImGui::Spacing(); + + // Update Add/Update variant button + ButtonConfig actionButton = variantButtons[0]; // Get our base Add Variant button + actionButton.id = "##" + std::string(m_editingVariantName.empty() ? "addVariant" : "updateVariant"); + actionButton.label = m_editingVariantName.empty() ? 
"Add Variant" : "Update Variant"; + actionButton.size = ImVec2(ImGui::GetContentRegionAvail().x, 0); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 16.0F); + Button::render(actionButton); + + ImGui::EndChild(); + ImGui::PopStyleVar(); + ImGui::PopStyleColor(); + } + } + + void openFileDialog() { + nfdu8char_t* outPath = nullptr; + nfdu8filteritem_t filters[1] = { {"GGUF Models", "gguf"} }; + + nfdopendialogu8args_t args{}; + args.filterList = filters; + args.filterCount = 1; + + nfdresult_t result = NFD_OpenDialogU8_With(&outPath, &args); + + if (result == NFD_OKAY) { + m_currentVariantPath = (const char*)outPath; + NFD_FreePathU8(outPath); + s_focusVariantPath = true; + } + else if (result == NFD_ERROR) { + m_variantErrorMessage = "Error opening file dialog: "; + m_variantErrorMessage += NFD_GetError(); + } + } + + bool validateMainForm() { + m_errorMessage.clear(); + + if (m_authorName.empty()) { + m_errorMessage = "Error: Author name cannot be empty"; + s_focusAuthor = true; + return false; + } + + if (m_modelName.empty()) { + m_errorMessage = "Error: Model name cannot be empty"; + s_focusModelName = true; + return false; + } + + if (m_variants.empty()) { + m_errorMessage = "Error: You must add at least one variant"; + return false; + } + + return true; + } + + bool validateVariantForm() { + m_variantErrorMessage.clear(); + + if (m_currentVariantName.empty()) { + m_variantErrorMessage = "Error: Variant name cannot be empty"; + s_focusVariantName = true; + return false; + } + + if (m_currentVariantPath.empty()) { + m_variantErrorMessage = "Error: Path/URL cannot be empty"; + s_focusVariantPath = true; + return false; + } + + // Check if variant already exists (only if we're adding a new one or changing the name) + if (m_currentVariantName != m_editingVariantName && + m_variants.find(m_currentVariantName) != m_variants.end()) { + m_variantErrorMessage = "Error: A variant with this name already exists"; + s_focusVariantName = true; + return false; + } + + return 
true; + } + + bool submitCustomModel() { + // Create a new ModelData instance + Model::ModelData modelData; + modelData.name = m_modelName; + modelData.author = m_authorName; + modelData.variants = m_variants; + + std::optional metadata; + for (const auto& [variantName, variant] : m_variants) { + if (!variant.downloadLink.empty()) { + metadata = m_ggufReader.readModelParams(variant.downloadLink, false); + break; + } + else { + metadata = m_ggufReader.readModelParams(variant.path, false); + break; + } + } + + if (!metadata.has_value()) { + m_errorMessage = "Error: Failed to read model metadata"; + return false; + } + + modelData.hidden_size = metadata->hidden_size; + modelData.attention_heads = metadata->attention_heads; + modelData.hidden_layers = metadata->hidden_layers; + modelData.kv_heads = metadata->kv_heads; + + // Call ModelManager to add the custom model + if (!Model::ModelManager::getInstance().addCustomModel(modelData)) { + m_errorMessage = "Error: Failed to add custom model. The model may already exist."; + return false; + } + + // Clear the form + clearForm(); + + return true; + } + + void clearForm() { + m_authorName.clear(); + m_modelName.clear(); + m_variants.clear(); + m_errorMessage.clear(); + m_showVariantForm = false; + m_currentVariantName.clear(); + m_currentVariantPath.clear(); + m_variantErrorMessage.clear(); + m_editingVariantName.clear(); + + // Reset focus flags + s_focusAuthor = true; + s_focusModelName = false; + s_focusVariantName = false; + s_focusVariantPath = false; + } +}; + +// Initialize static members +bool AddCustomModelModalComponent::s_focusAuthor = true; +bool AddCustomModelModalComponent::s_focusModelName = false; +bool AddCustomModelModalComponent::s_focusVariantName = false; +bool AddCustomModelModalComponent::s_focusVariantPath = false; +int AddCustomModelModalComponent::s_idCounter = 0; + +class DeleteModelModalComponent { +public: + DeleteModelModalComponent() { + ButtonConfig cancelButton; + cancelButton.id = 
"##cancelDeleteModel"; + cancelButton.label = "Cancel"; + cancelButton.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + cancelButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); + cancelButton.activeColor = RGBAToImVec4(26, 95, 180, 255); + cancelButton.textColor = RGBAToImVec4(255, 255, 255, 255); + cancelButton.size = ImVec2(130, 0); + cancelButton.onClick = []() { ImGui::CloseCurrentPopup(); }; + + ButtonConfig confirmButton; + confirmButton.id = "##confirmDeleteModel"; + confirmButton.label = "Confirm"; + confirmButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); + confirmButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); + confirmButton.activeColor = RGBAToImVec4(26, 95, 180, 255); + confirmButton.size = ImVec2(130, 0); + confirmButton.onClick = [this]() { + if (m_index != -1 && !m_variant.empty()) { + Model::ModelManager::getInstance().deleteDownloadedModel(m_index, m_variant); + ImGui::CloseCurrentPopup(); + } + }; + + buttons.push_back(cancelButton); + buttons.push_back(confirmButton); + } + + void setModel(int index, const std::string& variant) { + m_index = index; + m_variant = variant; + } + + void render(bool& openModal) { + if (m_index == -1 || m_variant.empty()) { + openModal = false; + return; + } + + ModalConfig config{ + "Confirm Delete Model", + "Confirm Delete Model", + ImVec2(300, 96), + [this]() { + Button::renderGroup(buttons, 16, ImGui::GetCursorPosY() + 8); + }, + openModal + }; + config.padding = ImVec2(16.0f, 8.0f); + ModalWindow::render(config); + + if (!ImGui::IsPopupOpen(config.id.c_str())) { + openModal = false; + m_index = -1; + m_variant.clear(); + } + } + +private: + int m_index = -1; + std::string m_variant; + std::vector buttons; +}; + +class ModelCardRenderer { +public: + ModelCardRenderer(int index, const Model::ModelData& modelData, + std::function onDeleteRequested, std::string id = "", bool allowSwitching = true) + : m_index(index), m_model(modelData), m_onDeleteRequested(onDeleteRequested), m_id(id) + { + 
selectButton.id = "##select" + std::to_string(m_index) + m_id; + selectButton.size = ImVec2(ModelManagerConstants::cardWidth - 18, 0); + + deleteButton.id = "##delete" + std::to_string(m_index) + m_id; + deleteButton.size = ImVec2(24, 0); + deleteButton.backgroundColor = RGBAToImVec4(200, 50, 50, 255); + deleteButton.hoverColor = RGBAToImVec4(220, 70, 70, 255); + deleteButton.activeColor = RGBAToImVec4(200, 50, 50, 255); + deleteButton.icon = ICON_CI_TRASH; + deleteButton.onClick = [this]() { + std::string currentVariant = Model::ModelManager::getInstance().getCurrentVariantForModel(m_model.name); + m_onDeleteRequested(m_index, currentVariant); + }; + + authorLabel.id = "##modelAuthor" + std::to_string(m_index) + m_id; + authorLabel.label = m_model.author; + authorLabel.size = ImVec2(0, 0); + authorLabel.fontType = FontsManager::ITALIC; + authorLabel.fontSize = FontsManager::SM; + authorLabel.alignment = Alignment::LEFT; + + nameLabel.id = "##modelName" + std::to_string(m_index) + m_id; + nameLabel.label = m_model.name; + nameLabel.size = ImVec2(0, 0); + nameLabel.fontType = FontsManager::BOLD; + nameLabel.fontSize = FontsManager::MD; + nameLabel.alignment = Alignment::LEFT; + + m_allowSwitching = allowSwitching; + } + + void render() { + auto& manager = Model::ModelManager::getInstance(); + std::string currentVariant = manager.getCurrentVariantForModel(m_model.name); + + ImGui::BeginGroup(); + ImGui::PushStyleColor(ImGuiCol_ChildBg, RGBAToImVec4(26, 26, 26, 255)); + ImGui::PushStyleVar(ImGuiStyleVar_ChildRounding, 8.0f); + + std::string childName = "ModelCard" + std::to_string(m_index) + m_id; + ImGui::BeginChild(childName.c_str(), ImVec2(ModelManagerConstants::cardWidth, ModelManagerConstants::cardHeight), true); + + renderHeader(); + ImGui::Spacing(); + renderVariantOptions(currentVariant); + + ImGui::SetCursorPosY(ModelManagerConstants::cardHeight - 35); + + bool isSelected = (m_model.name == manager.getCurrentModelName() && + currentVariant == 
manager.getCurrentVariantType()); + bool isDownloaded = manager.isModelDownloaded(m_index, currentVariant); + + if (!isDownloaded) { + double progress = manager.getModelDownloadProgress(m_index, currentVariant); + if (progress > 0.0) { + selectButton.label = "Cancel"; + selectButton.backgroundColor = RGBAToImVec4(200, 50, 50, 255); + selectButton.hoverColor = RGBAToImVec4(220, 70, 70, 255); + selectButton.activeColor = RGBAToImVec4(200, 50, 50, 255); + selectButton.icon = ICON_CI_CLOSE; + selectButton.onClick = [this, currentVariant]() { + Model::ModelManager::getInstance().cancelDownload(m_index, currentVariant); + }; + + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 12); + float fraction = static_cast(progress) / 100.0f; + ProgressBar::render(fraction, ImVec2(ModelManagerConstants::cardWidth - 18, 6)); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); + } + else { + selectButton.label = "Download"; + selectButton.backgroundColor = RGBAToImVec4(26, 95, 180, 255); + selectButton.hoverColor = RGBAToImVec4(53, 132, 228, 255); + selectButton.activeColor = RGBAToImVec4(26, 95, 180, 255); + selectButton.icon = ICON_CI_CLOUD_DOWNLOAD; + selectButton.borderSize = 1.0f; + selectButton.onClick = [this, currentVariant]() { + Model::ModelManager::getInstance().setPreferredVariant(m_model.name, currentVariant); + Model::ModelManager::getInstance().downloadModel(m_index, currentVariant); + }; + } + } + else { + bool isLoadingSelected = manager.isLoadInProgress() && m_model.name == manager.getCurrentOnLoadingModel(); + bool isUnloading = manager.isUnloadInProgress() && m_model.name == manager.getCurrentOnUnloadingModel(); + + // Configure button label and base state + if (isLoadingSelected || isUnloading) { + selectButton.label = isLoadingSelected ? "Loading Model..." 
: "Unloading Model..."; + selectButton.state = ButtonState::DISABLED; + selectButton.icon = ""; // Clear any existing icon + selectButton.borderSize = 0.0f; // Remove border + } + else { + if (m_allowSwitching) { + selectButton.label = isSelected ? "Selected" : "Select"; + } + else { + selectButton.label = manager.isModelLoaded(m_model.name) + ? "Unload" : "Load Model"; + } + } + + // Base styling (applies to all states) + selectButton.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + + // Disabled state for non-selected loading + if (!isSelected && manager.isLoadInProgress()) { + selectButton.state = ButtonState::DISABLED; + } + + // Common properties + selectButton.onClick = [this, &manager]() { + if (m_allowSwitching) + { + std::string variant = manager.getCurrentVariantForModel(m_model.name); + manager.switchModel(m_model.name, variant); + } + else + { + if (manager.isModelLoaded(m_model.name)) + { + manager.unloadModel(m_model.name); + } + else + { + manager.loadModelIntoEngine(m_model.name); + } + } + }; + selectButton.size = ImVec2(ModelManagerConstants::cardWidth - 18 - 5 - 24, 0); + + // Selected state styling (only if not loading) + if (isSelected && !isLoadingSelected) { + selectButton.borderColor = RGBAToImVec4(172, 131, 255, 255 / 4); + selectButton.borderSize = 1.0f; + selectButton.state = ButtonState::NORMAL; + selectButton.tooltip = "Click to unload model from memory"; + selectButton.onClick = [this, &manager]() { + manager.unloadModel(m_model.name); + }; + } + + // Add progress bar if in loading-selected state + if (isLoadingSelected || isUnloading) { + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 12); + ProgressBar::render(0, ImVec2(ModelManagerConstants::cardWidth - 18, 6)); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); + } + } + + Button::render(selectButton); + + if (isDownloaded) { + ImGui::SameLine(); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 2); + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + ImGui::GetContentRegionAvail().x - 
24 - 2); + + if (isSelected && manager.isLoadInProgress()) + deleteButton.state = ButtonState::DISABLED; + else + deleteButton.state = ButtonState::NORMAL; + + if (manager.isModelLoaded(m_model.name)) + { + deleteButton.icon = ICON_CI_ARROW_UP; + deleteButton.tooltip = "Click to unload model"; + deleteButton.onClick = [this, &manager]() { + std::cout << "[ModelManagerModal] Unloading model from delete button: " << m_model.name << "\n"; + manager.unloadModel(m_model.name); + }; + } + else + { + deleteButton.icon = ICON_CI_TRASH; + deleteButton.tooltip = "Click to delete model"; + deleteButton.onClick = [this, &manager]() { + std::string currentVariant = manager.getCurrentVariantForModel(m_model.name); + m_onDeleteRequested(m_index, currentVariant); + }; + } + + Button::render(deleteButton); + } + + ImGui::EndChild(); + if (ImGui::IsItemHovered() || isSelected) { + ImVec2 min = ImGui::GetItemRectMin(); + ImVec2 max = ImGui::GetItemRectMax(); + ImU32 borderColor = IM_COL32(172, 131, 255, 255 / 2); + ImGui::GetWindowDrawList()->AddRect(min, max, borderColor, 8.0f, 0, 1.0f); + } + + ImGui::PopStyleVar(); + ImGui::PopStyleColor(); + ImGui::EndGroup(); + } + +private: + int m_index; + std::string m_id; + const Model::ModelData& m_model; + std::function m_onDeleteRequested; + bool m_allowSwitching; + + void renderHeader() { + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 2); + Label::render(authorLabel); + + ImGui::SameLine(); + + float modelMemoryRequirement = 0; + float kvramMemoryRequirement = 0; + + Model::ModelManager& manager = Model::ModelManager::getInstance(); + bool hasEnoughMemory = manager.hasEnoughMemoryForModel(m_model.name, + modelMemoryRequirement, kvramMemoryRequirement); + + ButtonConfig memorySufficientButton; + memorySufficientButton.id = "##memorySufficient" + std::to_string(m_index) + m_id; + memorySufficientButton.icon = ICON_CI_PASS_FILLED; + memorySufficientButton.size = ImVec2(24, 0); + memorySufficientButton.tooltip = "Sufficient memory 
available\n\nmodel: " + + std::to_string(static_cast(modelMemoryRequirement)) + " MB\nkv cache: " + + std::to_string(static_cast(kvramMemoryRequirement)) + " MB"; + + if (!hasEnoughMemory) + { + memorySufficientButton.icon = ICON_CI_WARNING; + memorySufficientButton.tooltip = "Not enough memory available\n\nmodel: " + + std::to_string(static_cast(modelMemoryRequirement)) + " MB\nkv cache: " + + std::to_string(static_cast(kvramMemoryRequirement)) + " MB"; + } + + // place it to the top right corner of the card + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 8); + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + ImGui::GetContentRegionAvail().x - 24 - 2); + + Button::render(memorySufficientButton); + + Label::render(nameLabel); + } + + void renderVariantOptions(const std::string& currentVariant) { + LabelConfig variantLabel; + variantLabel.id = "##variantLabel" + std::to_string(m_index); + variantLabel.label = "Model Variants"; + variantLabel.size = ImVec2(0, 0); + variantLabel.fontType = FontsManager::REGULAR; + variantLabel.fontSize = FontsManager::SM; + variantLabel.alignment = Alignment::LEFT; + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 2); + Label::render(variantLabel); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 4); + + // Calculate the height for the scrollable area + // Card height minus header space minus button space at bottom + const float variantAreaHeight = 80.0f; + + // Create a scrollable child window for variants + ImGui::BeginChild(("##VariantScroll" + std::to_string(m_index)).c_str(), + ImVec2(ModelManagerConstants::cardWidth - 18, variantAreaHeight), + false); + + // Helper function to render a single variant option + auto renderVariant = [this, ¤tVariant](const std::string& variant) { + ButtonConfig btnConfig; + btnConfig.id = "##" + variant + std::to_string(m_index); + btnConfig.icon = (currentVariant == variant) ? ICON_CI_CHECK : ICON_CI_CLOSE; + btnConfig.textColor = (currentVariant != variant) ? 
RGBAToImVec4(34, 34, 34, 255) : ImVec4(1, 1, 1, 1); + btnConfig.fontSize = FontsManager::SM; + btnConfig.size = ImVec2(24, 0); + btnConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + btnConfig.onClick = [variant, this]() { + Model::ModelManager::getInstance().setPreferredVariant(m_model.name, variant); + }; + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + 4); + Button::render(btnConfig); + + ImGui::SameLine(0.0f, 4.0f); + LabelConfig variantLabel; + variantLabel.id = "##" + variant + "Label" + std::to_string(m_index); + variantLabel.label = variant; + variantLabel.size = ImVec2(0, 0); + variantLabel.fontType = FontsManager::REGULAR; + variantLabel.fontSize = FontsManager::SM; + variantLabel.alignment = Alignment::LEFT; + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 6); + Label::render(variantLabel); + }; + + // Iterate through all variants in the model + for (const auto& [variant, variantData] : m_model.variants) { + // For each variant, render a button + renderVariant(variant); + ImGui::Spacing(); + } + + // End the scrollable area + ImGui::EndChild(); + } + + ButtonConfig deleteButton; + ButtonConfig selectButton; + LabelConfig nameLabel; + LabelConfig authorLabel; +}; + +struct SortableModel { + int index; + std::string name; + bool hasSufficientMemory; + + bool operator<(const SortableModel& other) const { + return name < other.name; + } +}; + +class ModelManagerModal { +public: + ModelManagerModal() : m_searchText(""), m_shouldFocusSearch(false), m_showSufficientMemoryOnly(false) {} + + void render(bool& showDialog, bool allowSwitching = true) { + auto& manager = Model::ModelManager::getInstance(); + + // Update sorted models when: + // - The modal is opened for the first time + // - A model is downloaded, deleted, or its status changed + bool needsUpdate = false; + + if (showDialog && !m_wasShowing) { + // Modal just opened - refresh the model list + needsUpdate = true; + // Focus the search field when the modal is opened + m_shouldFocusSearch = true; 
+ } + + // Check for changes in download status + const auto& models = manager.getModels(); + if (models.size() != m_lastModelCount) { + // The model count changed + needsUpdate = true; + } + + // Check if a model was added through the custom model form + if (m_addCustomModelModal.wasModelAdded()) { + needsUpdate = true; + m_addCustomModelModal.resetModelAddedFlag(); + } + + // Check for changes in downloaded status + if (!needsUpdate) { + std::unordered_set currentDownloaded; + + for (size_t i = 0; i < models.size(); ++i) { + // Check if ANY variant is downloaded instead of just the current one + if (manager.isAnyVariantDownloaded(static_cast(i))) { + currentDownloaded.insert(models[i].name); // Don't need to add variant to the key + } + } + + if (currentDownloaded != m_lastDownloadedStatus) { + needsUpdate = true; + m_lastDownloadedStatus = std::move(currentDownloaded); + } + } + + if (needsUpdate) { + updateSortedModels(); + m_lastModelCount = models.size(); + filterModels(); // Apply the current search filter to the updated models + } + + m_wasShowing = showDialog; + + ImVec2 windowSize = ImGui::GetWindowSize(); + if (windowSize.x == 0) windowSize = ImGui::GetMainViewport()->Size; + const float targetWidth = windowSize.x; + float availableWidth = targetWidth - (2 * ModelManagerConstants::padding); + + int numCards = static_cast(availableWidth / (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)); + float modalWidth = (numCards * (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)) + (2 * ModelManagerConstants::padding); + if (targetWidth - modalWidth > (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing) * 0.5f) { + ++numCards; + modalWidth = (numCards * (ModelManagerConstants::cardWidth + ModelManagerConstants::cardSpacing)) + (2 * ModelManagerConstants::padding); + } + ImVec2 modalSize = ImVec2(modalWidth, windowSize.y * ModelManagerConstants::modalVerticalScale); + + auto renderCards = [numCards, 
this, &manager, allowSwitching]() { + const auto& models = manager.getModels(); + + // Render search field at the top + renderSearchField(); + ImGui::SetCursorPosX(ImGui::GetCursorPosX() + 12.0F); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 12.0F); + + ButtonConfig addCustomModelBtn; + addCustomModelBtn.id = "##addCustomModel"; + addCustomModelBtn.label = "Add Custom Model"; + addCustomModelBtn.icon = ICON_CI_PLUS; + addCustomModelBtn.backgroundColor = ImVec4(0.3, 0.3, 0.3, 0.3); + addCustomModelBtn.hoverColor = ImVec4(0.2, 0.2, 0.2, 0.2); + addCustomModelBtn.size = ImVec2(180, 32.0f); + addCustomModelBtn.onClick = [this]() { + m_addCustomModelModalOpen = true; + }; + Button::render(addCustomModelBtn); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 12.0F); + + if (m_addCustomModelModalOpen) { + m_addCustomModelModal.render(m_addCustomModelModalOpen); + } + + if (m_deleteModalOpen) { + m_deleteModal.render(m_deleteModalOpen); + } + + LabelConfig downloadedSectionLabel; + downloadedSectionLabel.id = "##downloadedModelsHeader"; + downloadedSectionLabel.label = "Downloaded Models"; + downloadedSectionLabel.size = ImVec2(0, 0); + downloadedSectionLabel.fontSize = FontsManager::LG; + downloadedSectionLabel.alignment = Alignment::LEFT; + + ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, ImGui::GetCursorPosY())); + Label::render(downloadedSectionLabel); + + // Add the "Show models with sufficient memory only" checkbox using custom widget + ImGui::SameLine(); + ImGui::SetCursorPosX(ImGui::GetContentRegionAvail().x - 32); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 2.0f); + + LabelConfig memoryFilterLabel; + memoryFilterLabel.id = "##memoryFilterCheckbox_label"; + memoryFilterLabel.label = "Show sufficient memory only"; + memoryFilterLabel.size = ImVec2(0, 0); + memoryFilterLabel.fontType = FontsManager::REGULAR; + memoryFilterLabel.fontSize = FontsManager::MD; + memoryFilterLabel.alignment = Alignment::LEFT; + + 
ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 4.0F); + + Label::render(memoryFilterLabel); + + ImGui::SameLine(); + + ButtonConfig memoryFilterBtn; + memoryFilterBtn.id = "##memoryFilterCheckbox"; + memoryFilterBtn.icon = m_showSufficientMemoryOnly ? ICON_CI_CHECK : ICON_CI_CLOSE; + memoryFilterBtn.textColor = m_showSufficientMemoryOnly ? ImVec4(1, 1, 1, 1) : ImVec4(0.6f, 0.6f, 0.6f, 1.0f); + memoryFilterBtn.fontSize = FontsManager::SM; + memoryFilterBtn.size = ImVec2(24, 24); + memoryFilterBtn.backgroundColor = m_showSufficientMemoryOnly ? Config::Color::PRIMARY : RGBAToImVec4(60, 60, 60, 255); + memoryFilterBtn.hoverColor = m_showSufficientMemoryOnly ? RGBAToImVec4(53, 132, 228, 255) : RGBAToImVec4(80, 80, 80, 255); + memoryFilterBtn.activeColor = m_showSufficientMemoryOnly ? RGBAToImVec4(26, 95, 180, 255) : RGBAToImVec4(100, 100, 100, 255); + memoryFilterBtn.tooltip = "Only show models that can run with your available memory"; + memoryFilterBtn.onClick = [this]() { + m_showSufficientMemoryOnly = !m_showSufficientMemoryOnly; + filterModels(); // Reapply filters when checkbox changes + }; + Button::render(memoryFilterBtn); + + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); + + // Count downloaded models and check if we have any + bool hasDownloadedModels = false; + int downloadedCardCount = 0; + + // First pass to check if we have any downloaded models + for (const auto& sortableModel : m_filteredModels) { + // Check if ANY variant is downloaded instead of just current variant + if (manager.isAnyVariantDownloaded(sortableModel.index)) { + hasDownloadedModels = true; + break; + } + } + + // Render downloaded models + if (hasDownloadedModels) { + for (const auto& sortableModel : m_filteredModels) { + // Check if ANY variant is downloaded instead of just current variant + if (manager.isAnyVariantDownloaded(sortableModel.index)) { + if (downloadedCardCount % numCards == 0) { + ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, + ImGui::GetCursorPosY() + 
(downloadedCardCount > 0 ? ModelManagerConstants::cardSpacing : 0))); + } + + ModelCardRenderer card(sortableModel.index, models[sortableModel.index], + [this](int index, const std::string& variant) { + m_deleteModal.setModel(index, variant); + m_deleteModalOpen = true; + }, "downloaded", allowSwitching); + card.render(); + + if ((downloadedCardCount + 1) % numCards != 0) { + ImGui::SameLine(0.0f, ModelManagerConstants::cardSpacing); + } + + downloadedCardCount++; + } + } + + // Add spacing before the next section + if (downloadedCardCount % numCards != 0) { + ImGui::NewLine(); + } + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + ModelManagerConstants::sectionSpacing); + } + else { + // Show a message if no downloaded models + LabelConfig noModelsLabel; + noModelsLabel.id = "##noDownloadedModels"; + noModelsLabel.label = m_searchText.empty() ? + "No downloaded models yet. Download models from the section below." : + "No downloaded models match your search. Try a different search term."; + noModelsLabel.size = ImVec2(0, 0); + noModelsLabel.fontType = FontsManager::ITALIC; + noModelsLabel.fontSize = FontsManager::MD; + noModelsLabel.alignment = Alignment::LEFT; + + ImGui::SetCursorPosX(ModelManagerConstants::padding); + Label::render(noModelsLabel); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + ModelManagerConstants::sectionSpacing); + } + + // Separator between sections + ImGui::SetCursorPosX(ModelManagerConstants::padding); + ImGui::PushStyleColor(ImGuiCol_Separator, ImVec4(0.3f, 0.3f, 0.3f, 0.5f)); + ImGui::Separator(); + ImGui::PopStyleColor(); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); + + // Render "Available Models" section header with custom checkbox + LabelConfig availableSectionLabel; + availableSectionLabel.id = "##availableModelsHeader"; + availableSectionLabel.label = "Available Models"; + availableSectionLabel.size = ImVec2(0, 0); + availableSectionLabel.fontSize = FontsManager::LG; + availableSectionLabel.alignment = Alignment::LEFT; + + 
ImGui::SetCursorPosX(ModelManagerConstants::padding); + Label::render(availableSectionLabel); + ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 10.0f); + + // Check if we have any available models that match the search and filters + if (m_filteredModels.empty()) { + LabelConfig noModelsLabel; + noModelsLabel.id = "##noAvailableModels"; + if (!m_searchText.empty()) { + noModelsLabel.label = "No models match your search. Try a different search term."; + } + else if (m_showSufficientMemoryOnly) { + noModelsLabel.label = "No models with sufficient memory found. Try disabling the memory filter."; + } + else { + noModelsLabel.label = "No models available."; + } + noModelsLabel.size = ImVec2(0, 0); + noModelsLabel.fontType = FontsManager::ITALIC; + noModelsLabel.fontSize = FontsManager::MD; + noModelsLabel.alignment = Alignment::LEFT; + + ImGui::SetCursorPosX(ModelManagerConstants::padding); + Label::render(noModelsLabel); + } + else { + // Render all models (available for download) + for (size_t i = 0; i < m_filteredModels.size(); ++i) { + if (i % numCards == 0) { + ImGui::SetCursorPos(ImVec2(ModelManagerConstants::padding, + ImGui::GetCursorPosY() + (i > 0 ? 
ModelManagerConstants::cardSpacing : 0))); + } + + ModelCardRenderer card(m_filteredModels[i].index, models[m_filteredModels[i].index], + [this](int index, const std::string& variant) { + m_deleteModal.setModel(index, variant); + m_deleteModalOpen = true; + }); + card.render(); + + if ((i + 1) % numCards != 0 && i < m_filteredModels.size() - 1) { + ImGui::SameLine(0.0f, ModelManagerConstants::cardSpacing); + } + } + } + }; + + ModalConfig config{ + "Model Manager", + "Model Manager", + modalSize, + renderCards, + showDialog + }; + config.padding = ImVec2(ModelManagerConstants::padding, 8.0f); + ModalWindow::render(config); + + if (m_needsUpdateAfterDelete && !m_deleteModalOpen) { + updateSortedModels(); + filterModels(); // Apply search filter after updating models + m_needsUpdateAfterDelete = false; + } + + if (!ImGui::IsPopupOpen(config.id.c_str())) { + showDialog = false; + } + } + +private: + DeleteModelModalComponent m_deleteModal; + bool m_deleteModalOpen = false; + bool m_wasShowing = false; + bool m_needsUpdateAfterDelete = false; + size_t m_lastModelCount = 0; + std::unordered_set m_lastDownloadedStatus; + std::vector m_sortedModels; + std::vector m_filteredModels; + + // Search related variables + std::string m_searchText; + bool m_shouldFocusSearch; + + // Memory filter checkbox state + bool m_showSufficientMemoryOnly; + + + AddCustomModelModalComponent m_addCustomModelModal; + bool m_addCustomModelModalOpen = false; + + void updateSortedModels() { + auto& manager = Model::ModelManager::getInstance(); + const auto& models = manager.getModels(); + + // Clear and rebuild the sorted model list + m_sortedModels.clear(); + m_sortedModels.reserve(models.size()); + + for (size_t i = 0; i < models.size(); ++i) { + // Check memory sufficiency status + float modelMemoryRequirement = 0; + float kvramMemoryRequirement = 0; + bool hasSufficientMemory = manager.hasEnoughMemoryForModel( + models[i].name, modelMemoryRequirement, kvramMemoryRequirement); + + // Store the 
index, name, and memory status + m_sortedModels.push_back({ + static_cast(i), + models[i].name, + hasSufficientMemory + }); + } + + // Sort models alphabetically by name + std::sort(m_sortedModels.begin(), m_sortedModels.end()); + + // Initialize filtered models with all models when sort is updated + filterModels(); + } + + // Filter models based on search text and memory filter + void filterModels() { + m_filteredModels.clear(); + auto& manager = Model::ModelManager::getInstance(); + const auto& models = manager.getModels(); + + // Convert search text to lowercase for case-insensitive comparison + std::string searchLower = m_searchText; + std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), + [](unsigned char c) { return std::tolower(c); }); + + // Filter models based on name OR author containing the search text + // AND the memory sufficiency if that filter is enabled + for (const auto& model : m_sortedModels) { + // Skip models that don't have sufficient memory if filter is enabled + if (m_showSufficientMemoryOnly && !model.hasSufficientMemory) { + continue; + } + + // If search text is empty, include the model (it already passed the memory filter) + if (searchLower.empty()) { + m_filteredModels.push_back(model); + continue; + } + + // Get the model data using the stored index + const auto& modelData = models[model.index]; + + // Convert name and author to lowercase for case-insensitive comparison + std::string nameLower = modelData.name; + std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), + [](unsigned char c) { return std::tolower(c); }); + + std::string authorLower = modelData.author; + std::transform(authorLower.begin(), authorLower.end(), authorLower.begin(), + [](unsigned char c) { return std::tolower(c); }); + + // Add model to filtered results if either name OR author contains the search text + if (nameLower.find(searchLower) != std::string::npos || + authorLower.find(searchLower) != std::string::npos) { + 
m_filteredModels.push_back(model); + } + } + } + + // New method: Render search field + void renderSearchField() { + ImGui::SetCursorPosX(ModelManagerConstants::padding); + + // Create and configure search input field + InputFieldConfig searchConfig( + "##modelSearch", + ImVec2(ImGui::GetContentRegionAvail().x, 32.0f), + m_searchText, + m_shouldFocusSearch + ); + searchConfig.placeholderText = "Search models..."; + searchConfig.processInput = [this](const std::string& text) { + // No need to handle submission specifically as we'll filter on every change + }; + + // Style the search field + searchConfig.backgroundColor = RGBAToImVec4(34, 34, 34, 255); + searchConfig.hoverColor = RGBAToImVec4(44, 44, 44, 255); + searchConfig.activeColor = RGBAToImVec4(54, 54, 54, 255); + + // Render the search field + InputField::render(searchConfig); + + // Filter models whenever search text changes + static std::string lastSearch; + if (lastSearch != m_searchText) { + lastSearch = m_searchText; + filterModels(); + } + } +}; \ No newline at end of file diff --git a/include/ui/server/deployment_settings.hpp b/include/ui/server/deployment_settings.hpp index 4a59aea..4017ad3 100644 --- a/include/ui/server/deployment_settings.hpp +++ b/include/ui/server/deployment_settings.hpp @@ -274,7 +274,7 @@ class DeploymentSettingsSidebar { void render() { ImGuiIO& io = ImGui::GetIO(); - const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT; + const float sidebarHeight = io.DisplaySize.y - Config::TITLE_BAR_HEIGHT - 40 - Config::FOOTER_HEIGHT; // Right sidebar window ImGui::SetNextWindowPos(ImVec2(io.DisplaySize.x - m_width, Config::TITLE_BAR_HEIGHT + 40), ImGuiCond_Always); diff --git a/include/ui/server/server_logs.hpp b/include/ui/server/server_logs.hpp index e9e7c22..963886b 100644 --- a/include/ui/server/server_logs.hpp +++ b/include/ui/server/server_logs.hpp @@ -2,7 +2,7 @@ #include "imgui.h" #include "ui/widgets.hpp" -#include "ui/chat/model_manager_modal.hpp" +#include 
"ui/model_manager_modal.hpp" #include "model/model_manager.hpp" #include "model/server_state_manager.hpp" @@ -33,7 +33,7 @@ class ServerLogViewer { ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0.0F); ImGui::SetNextWindowPos(ImVec2(0, Config::TITLE_BAR_HEIGHT), ImGuiCond_Always); - ImGui::SetNextWindowSize(ImVec2(io.DisplaySize.x - sidebarWidth, io.DisplaySize.y - Config::TITLE_BAR_HEIGHT), ImGuiCond_Always); + ImGui::SetNextWindowSize(ImVec2(io.DisplaySize.x - sidebarWidth, io.DisplaySize.y - Config::TITLE_BAR_HEIGHT - Config::FOOTER_HEIGHT), ImGuiCond_Always); ImGui::Begin("Server Logs", nullptr, window_flags); ImGui::PopStyleVar(); @@ -143,7 +143,7 @@ class ServerLogViewer { Button::render(copyButtonConfig); } - m_modelManagerModal.render(m_modelManagerModalOpen); + m_modelManagerModal.render(m_modelManagerModalOpen, false); } ImGui::SetCursorPosY(ImGui::GetCursorPosY() + 12); diff --git a/include/ui/status_bar.hpp b/include/ui/status_bar.hpp new file mode 100644 index 0000000..3d3b86c --- /dev/null +++ b/include/ui/status_bar.hpp @@ -0,0 +1,190 @@ +#pragma once + +#include "../config.hpp" +#include "../system_monitor.hpp" +#include "widgets.hpp" +#include "fonts.hpp" +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#include +#else +#include +#include +#include +#endif + +class StatusBar { +public: + StatusBar() + : lastUpdateTime(std::chrono::steady_clock::now()) + , updateInterval(1000) + { + // Get the username once during initialization + getCurrentUsername(); + updateCurrentTime(); + } + + void render() { + ImGuiIO& io = ImGui::GetIO(); + + // Get the instance of SystemMonitor + SystemMonitor& sysMonitor = SystemMonitor::getInstance(); + sysMonitor.update(); + + // Only update metrics occasionally to reduce CPU impact + auto currentTime = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(currentTime - lastUpdateTime).count() > updateInterval) { + // Update SystemMonitor to get fresh stats + 
sysMonitor.update(); + updateCurrentTime(); + lastUpdateTime = currentTime; + } + + // Calculate status bar position and size + ImVec2 windowPos(0, io.DisplaySize.y - Config::FOOTER_HEIGHT); + ImVec2 windowSize(io.DisplaySize.x, Config::FOOTER_HEIGHT); + + // Begin a window with specific position and size + ImGui::SetNextWindowPos(windowPos); + ImGui::SetNextWindowSize(windowSize); + + // Status bar window flags + ImGuiWindowFlags window_flags = + ImGuiWindowFlags_NoTitleBar | + ImGuiWindowFlags_NoResize | + ImGuiWindowFlags_NoMove | + ImGuiWindowFlags_NoScrollbar | + ImGuiWindowFlags_NoSavedSettings | + ImGuiWindowFlags_NoBringToFrontOnFocus; + + // Minimal styling + ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f); + ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 1.0f); + ImGui::PushStyleColor(ImGuiCol_WindowBg, ImVec4(0.1f, 0.1f, 0.1f, 0.4f)); + + if (ImGui::Begin("##StatusBar", nullptr, window_flags)) { + // Left side: Version + LabelConfig versionLabel; + versionLabel.id = "##versionLabel"; + versionLabel.label = "Version: " + std::string(APP_VERSION); + versionLabel.size = ImVec2(200, 20); + versionLabel.fontSize = FontsManager::SM; + + ImGui::SetCursorPosY(ImGui::GetCursorPosY() - 10); + + Label::render(versionLabel); + + ImGui::SameLine(); + + // Get metrics from SystemMonitor + float cpuUsage = sysMonitor.getCpuUsagePercentage(); + size_t memoryUsageMB = sysMonitor.getUsedMemoryByProcess() / (1024 * 1024); + + // Format the CPU usage with one decimal place + std::stringstream cpuSS; + cpuSS << std::fixed << std::setprecision(1) << cpuUsage; + + // Prepare buttons for system metrics + ButtonConfig cpuUsageLabel; + cpuUsageLabel.id = "##cpuUsageLabel"; + cpuUsageLabel.label = "CPU: " + cpuSS.str() + "%"; + cpuUsageLabel.size = ImVec2(100, 20); + cpuUsageLabel.fontSize = FontsManager::SM; + + ButtonConfig memoryUsageLabel; + memoryUsageLabel.id = "##memoryUsageLabel"; + memoryUsageLabel.label = "Memory: " + std::to_string(memoryUsageMB) + " MB"; + 
memoryUsageLabel.size = ImVec2(150, 20); + memoryUsageLabel.fontSize = FontsManager::SM; + + // Create buttons for GPU metrics if available + std::vector buttonConfigs = { cpuUsageLabel, memoryUsageLabel }; + + // Right-align the time display + float contentWidth = ImGui::GetContentRegionAvail().x; + float timeWidth = 150; // Approximate width needed for time display + + if (sysMonitor.hasGpuSupport()) { + size_t gpuUsageMB = sysMonitor.getUsedGpuMemoryByProcess() / (1024 * 1024); + ButtonConfig gpuUsageLabel; + gpuUsageLabel.id = "##gpuUsageLabel"; + gpuUsageLabel.label = "GPU Memory: " + std::to_string(gpuUsageMB) + " MB"; + gpuUsageLabel.size = ImVec2(180, 20); + gpuUsageLabel.fontSize = FontsManager::SM; + buttonConfigs.push_back(gpuUsageLabel); + timeWidth += 180; + } + + Button::renderGroup(buttonConfigs, contentWidth - timeWidth, + ImGui::GetCursorPosY() - 2, 0); + } + ImGui::End(); + + ImGui::PopStyleVar(2); + ImGui::PopStyleColor(); + } + +private: + std::chrono::steady_clock::time_point lastUpdateTime; + int updateInterval; + std::string username; + char timeBuffer[64]; + + void getCurrentUsername() { +#ifdef _WIN32 + // Windows implementation + char buffer[UNLEN + 1]; + DWORD size = UNLEN + 1; + if (GetUserNameA(buffer, &size)) { + username = buffer; + } + else { + username = "unknown"; + } +#else + // Unix/Linux/macOS implementation + const char* user = getenv("USER"); + if (user != nullptr) { + username = user; + } + else { + // Fallback to getpwuid if USER environment variable is not available + struct passwd* pwd = getpwuid(geteuid()); + if (pwd != nullptr) { + username = pwd->pw_name; + } + else { + username = "unknown"; + } + } +#endif + } + + void updateCurrentTime() { + auto now = std::chrono::system_clock::now(); + auto now_time_t = std::chrono::system_clock::to_time_t(now); + + // Format time as UTC with the specified format + std::tm tm_utc; +#ifdef _WIN32 + gmtime_s(&tm_utc, &now_time_t); +#else + gmtime_r(&now_time_t, &tm_utc); +#endif + + 
std::stringstream ss; + ss << std::put_time(&tm_utc, "%Y-%m-%d %H:%M:%S"); + std::string str = ss.str(); + + // Copy to our buffer + strncpy(timeBuffer, str.c_str(), sizeof(timeBuffer) - 1); + timeBuffer[sizeof(timeBuffer) - 1] = '\0'; + } +}; \ No newline at end of file diff --git a/include/ui/tab_manager.hpp b/include/ui/tab_manager.hpp index 16f0bb2..420df7d 100644 --- a/include/ui/tab_manager.hpp +++ b/include/ui/tab_manager.hpp @@ -24,7 +24,6 @@ class ITab { virtual const char* getIcon() const = 0; }; -// Update ChatTab to implement the new methods class ChatTab : public ITab { public: ChatTab() @@ -81,7 +80,6 @@ class ServerTab : public ITab { DeploymentSettingsSidebar deploymentSettingsSidebar; }; -// Update TabManager to handle tab activation/deactivation class TabManager { public: TabManager() : activeTabIndex(0) {} diff --git a/include/ui/widgets.hpp b/include/ui/widgets.hpp index c2fc39e..f6e4241 100644 --- a/include/ui/widgets.hpp +++ b/include/ui/widgets.hpp @@ -965,6 +965,8 @@ namespace ModalWindow } ImGui::PushStyleColor(ImGuiCol_ModalWindowDimBg, ImVec4(0.0F, 0.0F, 0.0F, 0.5F)); + ImGui::PushStyleColor(ImGuiCol_PopupBg, ImVec4(0.075F, 0.075F, 0.075F, 1.0F)); + ImGui::PushStyleColor(ImGuiCol_ScrollbarBg, ImVec4(0, 0, 0, 0)); ImGui::SetNextWindowPos(ImGui::GetMainViewport()->GetCenter(), ImGuiCond_Always, ImVec2(0.5F, 0.5F)); ImGui::SetNextWindowSize(config.size); @@ -1006,7 +1008,7 @@ namespace ModalWindow ImGui::EndPopup(); } - ImGui::PopStyleColor(); + ImGui::PopStyleColor(3); } } diff --git a/include/window/gradient_background.hpp b/include/window/gradient_background.hpp deleted file mode 100644 index fa5178f..0000000 --- a/include/window/gradient_background.hpp +++ /dev/null @@ -1,267 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -GLuint g_shaderProgram = 0; -GLuint g_gradientTexture = 0; - -const char* g_quadVertexShaderSource = R"( -#version 330 core -layout(location = 0) in vec2 aPos; 
-layout(location = 1) in vec2 aTexCoord; - -out vec2 TexCoord; - -void main() -{ - TexCoord = aTexCoord; - gl_Position = vec4(aPos, 0.0, 1.0); -} -)"; - -const char* g_quadFragmentShaderSource = R"( -#version 330 core -in vec2 TexCoord; -out vec4 FragColor; - -uniform sampler2D gradientTexture; -uniform float uTransitionProgress; - -void main() -{ - vec4 color = texture(gradientTexture, TexCoord); - color.a *= uTransitionProgress; // Adjust the alpha based on transition progress - FragColor = color; -} -)"; - -GLuint g_quadVAO = 0; -GLuint g_quadVBO = 0; -GLuint g_quadEBO = 0; - -namespace GradientBackground { - - void checkShaderCompileErrors(GLuint shader, const std::string& type) { - GLint success; - GLchar infoLog[1024]; - if (type != "PROGRAM") { - glGetShaderiv(shader, GL_COMPILE_STATUS, &success); - if (!success) { - glGetShaderInfoLog(shader, 1024, NULL, infoLog); - std::cerr << "| ERROR::SHADER_COMPILATION_ERROR of type: " << type << "|\n" - << infoLog << "\n -- --------------------------------------------------- -- " << std::endl; - } - } - } - - void checkProgramLinkErrors(GLuint program) { - GLint success; - GLchar infoLog[1024]; - glGetProgramiv(program, GL_LINK_STATUS, &success); - if (!success) { - glGetProgramInfoLog(program, 1024, NULL, infoLog); - std::cerr << "| ERROR::PROGRAM_LINKING_ERROR |\n" - << infoLog << "\n -- --------------------------------------------------- -- " << std::endl; - } - } - - void generateGradientTexture(int width, int height) { - // Delete the existing texture if it exists - if (g_gradientTexture != 0) { - glDeleteTextures(1, &g_gradientTexture); - } - - // Create a new texture - glGenTextures(1, &g_gradientTexture); - glBindTexture(GL_TEXTURE_2D, g_gradientTexture); - - // Allocate a buffer to hold the gradient data - std::vector gradientData(width * height * 4); - - // Define the start and end colors (RGBA) - ImVec4 colorTopLeft = ImVec4(0.05f, 0.07f, 0.12f, 1.0f); // Dark Blue - ImVec4 colorBottomRight = ImVec4(0.16f, 
0.14f, 0.08f, 1.0f); // Dark Green - - // Generate the gradient data - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - float t_x = static_cast(x) / (width - 1); - float t_y = static_cast(y) / (height - 1); - float t = (t_x + t_y) / 2.0f; // Diagonal gradient - - ImVec4 pixelColor = ImLerp(colorTopLeft, colorBottomRight, t); - - unsigned char r = static_cast(pixelColor.x * 255); - unsigned char g = static_cast(pixelColor.y * 255); - unsigned char b = static_cast(pixelColor.z * 255); - unsigned char a = static_cast(pixelColor.w * 255); - - int index = (y * width + x) * 4; - gradientData[index + 0] = r; - gradientData[index + 1] = g; - gradientData[index + 2] = b; - gradientData[index + 3] = a; - } - } - - // Upload the gradient data to the texture - glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, width, height, 0, GL_RGBA, GL_UNSIGNED_BYTE, gradientData.data()); - - // Set texture parameters - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); - - // Unbind the texture - glBindTexture(GL_TEXTURE_2D, 0); - } - - GLuint compileShader(GLenum type, const char* source) { - GLuint shader = glCreateShader(type); - glShaderSource(shader, 1, &source, NULL); - glCompileShader(shader); - - // Check for compile errors - checkShaderCompileErrors(shader, (type == GL_VERTEX_SHADER) ? 
"VERTEX" : "FRAGMENT"); - - return shader; - } - - GLuint createShaderProgram(const char* vertexSource, const char* fragmentSource) { - GLuint vertexShader = compileShader(GL_VERTEX_SHADER, vertexSource); - GLuint fragmentShader = compileShader(GL_FRAGMENT_SHADER, fragmentSource); - - // Link shaders into a program - GLuint program = glCreateProgram(); - glAttachShader(program, vertexShader); - glAttachShader(program, fragmentShader); - glLinkProgram(program); - - // Check for linking errors - checkProgramLinkErrors(program); - - // Clean up shaders - glDeleteShader(vertexShader); - glDeleteShader(fragmentShader); - - return program; - } - - void setupFullScreenQuad() { - float quadVertices[] = { - // Positions // Texture Coords - -1.0f, 1.0f, 0.0f, 1.0f, // Top-left - -1.0f, -1.0f, 0.0f, 0.0f, // Bottom-left - 1.0f, -1.0f, 1.0f, 0.0f, // Bottom-right - 1.0f, 1.0f, 1.0f, 1.0f // Top-right - }; - - unsigned int quadIndices[] = { - 0, 1, 2, // First triangle - 0, 2, 3 // Second triangle - }; - - glGenVertexArrays(1, &g_quadVAO); - glGenBuffers(1, &g_quadVBO); - glGenBuffers(1, &g_quadEBO); - - glBindVertexArray(g_quadVAO); - - // Vertex Buffer - glBindBuffer(GL_ARRAY_BUFFER, g_quadVBO); - glBufferData(GL_ARRAY_BUFFER, sizeof(quadVertices), quadVertices, GL_STATIC_DRAW); - - // Element Buffer - glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, g_quadEBO); - glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(quadIndices), quadIndices, GL_STATIC_DRAW); - - // Position Attribute - glEnableVertexAttribArray(0); - glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float), (void*)0); - - // Texture Coordinate Attribute - glEnableVertexAttribArray(1); - glVertexAttribPointer(1, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float), (void*)(2 * sizeof(float))); - - glBindVertexArray(0); - } - - void renderGradientBackground(int display_w, int display_h, float transitionProgress, float easedProgress) { - // Set the viewport and clear the screen - glViewport(0, 0, display_w, display_h); - 
glClearColor(0, 0, 0, 0); // Clear with transparent color if blending is enabled - glClear(GL_COLOR_BUFFER_BIT); - - // Disable depth test and face culling - glDisable(GL_DEPTH_TEST); - glDisable(GL_CULL_FACE); - - // Render the gradient texture as background - if (transitionProgress > 0.0f) { - // Enable blending - glEnable(GL_BLEND); - glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA); - - // Use the shader program - glUseProgram(g_shaderProgram); - - // Bind the gradient texture - glActiveTexture(GL_TEXTURE0); - glBindTexture(GL_TEXTURE_2D, g_gradientTexture); - - // Set the sampler uniform - glUniform1i(glGetUniformLocation(g_shaderProgram, "gradientTexture"), 0); - - // Set the transition progress uniform - GLint locTransitionProgress = glGetUniformLocation(g_shaderProgram, "uTransitionProgress"); - glUniform1f(locTransitionProgress, easedProgress); // Use easedProgress - - // Render the full-screen quad - glBindVertexArray(g_quadVAO); - glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, 0); - glBindVertexArray(0); - - // Unbind the shader program - glUseProgram(0); - - // Disable blending if necessary - // glDisable(GL_BLEND); - } - } - - void CleanUp() - { - if (g_gradientTexture != 0) - { - glDeleteTextures(1, &g_gradientTexture); - g_gradientTexture = 0; - } - if (g_quadVAO != 0) - { - glDeleteVertexArrays(1, &g_quadVAO); - g_quadVAO = 0; - } - if (g_quadVBO != 0) - { - glDeleteBuffers(1, &g_quadVBO); - g_quadVBO = 0; - } - if (g_quadEBO != 0) - { - glDeleteBuffers(1, &g_quadEBO); - g_quadEBO = 0; - } - if (g_shaderProgram != 0) - { - glDeleteProgram(g_shaderProgram); - g_shaderProgram = 0; - } - } - -} // namespace GradientBackground \ No newline at end of file diff --git a/installer/script.nsi b/installer/script.nsi index 0a0ad76..aba8fbb 100644 --- a/installer/script.nsi +++ b/installer/script.nsi @@ -23,7 +23,7 @@ Var IsUpgrade ;----------------------------------- ; Embed version info (metadata) ;----------------------------------- -!define VERSION 
"0.1.7.0" +!define VERSION "0.1.8.0" VIProductVersion "${VERSION}" VIAddVersionKey "ProductName" "Kolosal AI Installer" VIAddVersionKey "CompanyName" "Genta Technology" @@ -142,15 +142,15 @@ Section "Kolosal AI" SecKolosalAI ; Force overwrite of existing files SetOverwrite on - ; If this is an upgrade, remove old files first (except chat history) + ; If this is an upgrade, remove old files first (except chat history and models folder) ${If} $IsUpgrade == "true" ; Display upgrade message DetailPrint "Upgrading from version $OldVersion to $NewVersion" - ; Remove previous program files but keep the directory structure + ; Remove previous program files but keep the models folder intact to preserve user-downloaded models. RMDir /r "$INSTDIR\assets" RMDir /r "$INSTDIR\fonts" - RMDir /r "$INSTDIR\models" + ; NOTE: Do NOT remove "$INSTDIR\models" so that any user-downloaded models are not deleted. Delete "$INSTDIR\*.dll" Delete "$INSTDIR\*.exe" Delete "$INSTDIR\LICENSE" @@ -171,8 +171,8 @@ Section "Kolosal AI" SecKolosalAI File "libcrypto-3-x64.dll" File "libssl-3-x64.dll" File "libcurl.dll" - FILE "kolosal_server.dll" - FILE "vcomp140.dll" + File "kolosal_server.dll" + File "vcomp140.dll" File "LICENSE" ; Create and populate subdirectories @@ -184,11 +184,12 @@ Section "Kolosal AI" SecKolosalAI SetOutPath "$INSTDIR\fonts" File /r "fonts\*.*" + ; Update files within models folder without deleting the folder itself CreateDirectory "$INSTDIR\models" SetOutPath "$INSTDIR\models" File /r "models\*.*" - ; Create chat history directory if it doesn't exist + ; Create chat history directory if it doesn't exist (for a new install) ${If} $IsUpgrade == "false" CreateDirectory "$ChatHistoryDir" AccessControl::GrantOnFile "$ChatHistoryDir" "(S-1-5-32-545)" "FullAccess" @@ -211,9 +212,9 @@ Section "Kolosal AI" SecKolosalAI WriteRegStr HKLM "SOFTWARE\KolosalAI" "ChatHistory_Dir" "$ChatHistoryDir" WriteRegStr HKLM "SOFTWARE\KolosalAI" "Version" "${VERSION}" - WriteRegStr HKCU 
"Software\KolosalAI" "Install_Dir" "$INSTDIR" - WriteRegStr HKCU "Software\KolosalAI" "ChatHistory_Dir" "$ChatHistoryDir" - WriteRegStr HKCU "Software\KolosalAI" "Version" "${VERSION}" + WriteRegStr HKCU "SOFTWARE\KolosalAI" "Install_Dir" "$INSTDIR" + WriteRegStr HKCU "SOFTWARE\KolosalAI" "ChatHistory_Dir" "$ChatHistoryDir" + WriteRegStr HKCU "SOFTWARE\KolosalAI" "Version" "${VERSION}" ; Write uninstaller registry information WriteRegStr HKLM "SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\KolosalAI" "DisplayName" "Kolosal AI" @@ -278,7 +279,7 @@ keepChatHistory: ; Remove directories and files RMDir /r "$INSTDIR\assets" RMDir /r "$INSTDIR\fonts" - RMDir /r "$INSTDIR\models" + RMDir /r "$INSTDIR\models" ; For uninstallation, the entire models folder is removed. Delete "$INSTDIR\*.*" RMDir "$INSTDIR" diff --git a/kolosal-server b/kolosal-server index ea06fc2..fc4985d 160000 --- a/kolosal-server +++ b/kolosal-server @@ -1 +1 @@ -Subproject commit ea06fc2ad047fc0143e7b0f24f6e46398398a0b6 +Subproject commit fc4985d3929f84b36abf17e81e947981a8703fd4 diff --git a/models/bahasa-ai-4b.json b/models/bahasa-ai-4b.json index ea1e4ce..14e9bdd 100644 --- a/models/bahasa-ai-4b.json +++ b/models/bahasa-ai-4b.json @@ -1,6 +1,10 @@ { "name": "Bahasa AI 4B", "author": "Alibaba, Bahasa AI", + "hidden_size": 2560, + "attention_heads": 20, + "hidden_layers": 40, + "kv_heads": 20, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/bahasa-ai-4b/resolve/main/Bahasalab_Bahasa-4b-chat_f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 7.91 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/bahasa-ai-4b/resolve/main/Bahasalab_Bahasa-4b-chat_q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.2 }, "4-bit Quantized": { "type": "4-bit 
Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/bahasa-ai-4b/resolve/main/Bahasalab_Bahasa-4b-chat_q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.46 } } } \ No newline at end of file diff --git a/models/deepseek-r1-llama-8b.json b/models/deepseek-r1-llama-8b.json index 6cd5cea..c3e3619 100644 --- a/models/deepseek-r1-llama-8b.json +++ b/models/deepseek-r1-llama-8b.json @@ -1,6 +1,10 @@ { "name": "Deepseek R1 Llama 8B", "author": "Deepseek AI", + "hidden_size": 4096, + "attention_heads": 32, + "hidden_layers": 32, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Llama-8B/resolve/main/DeepSeek-R1-Distill-Llama-8B-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 16.1 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Llama-8B/resolve/main/DeepSeek-R1-Distill-Llama-8B-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.54 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Llama-8B/resolve/main/DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.92 } } } \ No newline at end of file diff --git a/models/deepseek-r1-qwen2.5-1.5b.json b/models/deepseek-r1-qwen2.5-1.5b.json index b078ba5..322101c 100644 --- a/models/deepseek-r1-qwen2.5-1.5b.json +++ b/models/deepseek-r1-qwen2.5-1.5b.json @@ -1,6 +1,10 @@ { "name": "Deepseek R1 Qwen2.5 1.5B", "author": "Deepseek AI", + "hidden_size": 1536, + "attention_heads": 12, + "hidden_layers": 28, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", 
@@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-1.5B/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.56 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-1.5B/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.89 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-1.5B/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.12 } } } \ No newline at end of file diff --git a/models/deepseek-r1-qwen2.5-14b.json b/models/deepseek-r1-qwen2.5-14b.json index d696ba1..dafdc9d 100644 --- a/models/deepseek-r1-qwen2.5-14b.json +++ b/models/deepseek-r1-qwen2.5-14b.json @@ -1,6 +1,10 @@ { "name": "Deepseek R1 Qwen2.5 14B", "author": "Deepseek AI", + "hidden_size": 5120, + "attention_heads": 40, + "hidden_layers": 48, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-14B/resolve/main/DeepSeek-R1-Distill-Qwen-14B-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 29.5 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-14B/resolve/main/DeepSeek-R1-Distill-Qwen-14B-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.7 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": 
"https://huggingface.co/kolosal/Deepseek-R1-Qwen-14B/resolve/main/DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.99 } } } \ No newline at end of file diff --git a/models/deepseek-r1-qwen2.5-7b.json b/models/deepseek-r1-qwen2.5-7b.json index 01c6619..888c64c 100644 --- a/models/deepseek-r1-qwen2.5-7b.json +++ b/models/deepseek-r1-qwen2.5-7b.json @@ -1,6 +1,10 @@ { "name": "Deepseek R1 Qwen2.5 7B", "author": "Deepseek AI", + "hidden_size": 3584, + "attention_heads": 28, + "hidden_layers": 28, + "kv_heads": 4, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-7B/resolve/main/DeepSeek-R1-Distill-Qwen-7B-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.2 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-7B/resolve/main/DeepSeek-R1-Distill-Qwen-7B-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.1 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/Deepseek-R1-Qwen-7B/resolve/main/DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.68 } } } \ No newline at end of file diff --git a/models/gemma-2-2b.json b/models/gemma-2-2b.json index 747b01b..47e9bdb 100644 --- a/models/gemma-2-2b.json +++ b/models/gemma-2-2b.json @@ -1,6 +1,10 @@ { "name": "Gemma 2 2B", "author": "Google", + "hidden_size": 2304, + "attention_heads": 8, + "hidden_layers": 26, + "kv_heads": 4, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": 
"https://huggingface.co/kolosal/gemma-2-2b/resolve/main/gemma-2-2b-it-f32.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 10.5 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-2b/resolve/main/gemma-2-2b-it-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.78 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-2b/resolve/main/gemma-2-2b-it-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.71 } } } \ No newline at end of file diff --git a/models/gemma-2-9b-sahabat.json b/models/gemma-2-9b-sahabat.json index 074152c..3f38742 100644 --- a/models/gemma-2-9b-sahabat.json +++ b/models/gemma-2-9b-sahabat.json @@ -1,6 +1,10 @@ { "name": "Gemma 2 9B Sahabat AI", "author": "Google, GoTo", + "hidden_size": 3584, + "attention_heads": 16, + "hidden_layers": 42, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b-sahabat-ai/resolve/main/gemma2-9b-cpt-sahabatai-v1-instruct.bf16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 18.5 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b-sahabat-ai/resolve/main/gemma2-9b-cpt-sahabatai-v1-instruct.Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 9.83 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b-sahabat-ai/resolve/main/gemma2-9b-cpt-sahabatai-v1-instruct.Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + 
"lastSelected": 0, + "size": 5.76 } } } \ No newline at end of file diff --git a/models/gemma-2-9b.json b/models/gemma-2-9b.json index fed6330..0f83d8f 100644 --- a/models/gemma-2-9b.json +++ b/models/gemma-2-9b.json @@ -1,6 +1,10 @@ { "name": "Gemma 2 9B", "author": "Google", + "hidden_size": 3584, + "attention_heads": 16, + "hidden_layers": 42, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b/resolve/main/gemma-2-9b-it-f32.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 37.0 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b/resolve/main/gemma-2-9b-it-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 9.83 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-2-9b/resolve/main/gemma-2-9b-it-Q4_K_L.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 5.98 } } } \ No newline at end of file diff --git a/models/gemma-3-12b.json b/models/gemma-3-12b.json index 4670b0c..efc1da8 100644 --- a/models/gemma-3-12b.json +++ b/models/gemma-3-12b.json @@ -1,6 +1,10 @@ { "name": "Gemma 3 12B", "author": "Google", + "hidden_size": 3840, + "attention_heads": 16, + "hidden_layers": 48, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-12b/resolve/main/google_gemma-3-12b-it_f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 23.5 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-12b/resolve/main/google_gemma-3-12b-it_q8_0.gguf", 
"isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 12.5 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-12b/resolve/main/google_gemma-3-12b-it_q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 7.3 } } } \ No newline at end of file diff --git a/models/gemma-3-1b.json b/models/gemma-3-1b.json index 599ed4d..46734d7 100644 --- a/models/gemma-3-1b.json +++ b/models/gemma-3-1b.json @@ -1,6 +1,10 @@ { "name": "Gemma 3 1B", "author": "Google", + "hidden_size": 1152, + "attention_heads": 4, + "hidden_layers": 26, + "kv_heads": 1, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-1b/resolve/main/google_gemma-3-1b-it_f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.01 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-1b/resolve/main/google_gemma-3-1b-it_q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.07 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-1b/resolve/main/google_gemma-3-1b-it_q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.806 } } } \ No newline at end of file diff --git a/models/gemma-3-27b.json b/models/gemma-3-27b.json index 226b51b..4ad3e10 100644 --- a/models/gemma-3-27b.json +++ b/models/gemma-3-27b.json @@ -1,6 +1,10 @@ { "name": "Gemma 3 27B", "author": "Google", + "hidden_size": 5376, + "attention_heads": 32, + "hidden_layers": 62, + "kv_heads": 16, "variants": { "8-bit Quantized": { "type": "8-bit Quantized", @@ -8,7 +12,8 @@ "downloadLink": 
"https://huggingface.co/kolosal/gemma-3-27b/resolve/main/google_gemma-3-27b-it_q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 28.7 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-27b/resolve/main/google_gemma-3-27b-it_q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 16.5 } } } \ No newline at end of file diff --git a/models/gemma-3-4b.json b/models/gemma-3-4b.json index c989026..67ca02a 100644 --- a/models/gemma-3-4b.json +++ b/models/gemma-3-4b.json @@ -1,6 +1,10 @@ { "name": "Gemma 3 4B", "author": "Google", + "hidden_size": 2560, + "attention_heads": 16, + "hidden_layers": 34, + "kv_heads": 4, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-4b/resolve/main/google_gemma-3-4b-it_f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 7.77 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-4b/resolve/main/google_gemma-3-4b-it_q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.13 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/gemma-3-4b/resolve/main/google_gemma-3-4b-it_q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.49 } } } \ No newline at end of file diff --git a/models/llama-3-8b-sahabat.json b/models/llama-3-8b-sahabat.json index f85bd3e..eba81cf 100644 --- a/models/llama-3-8b-sahabat.json +++ b/models/llama-3-8b-sahabat.json @@ -1,6 +1,10 @@ { "name": "Llama 3 8B Sahabat AI", "author": "Meta, GoTo", + "hidden_size": 4096, + "attention_heads": 32, + 
"hidden_layers": 32, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3-8b-sahabat-ai/resolve/main/llama3-8b-cpt-sahabatai-v1-instruct.bf16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 16.1 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3-8b-sahabat-ai/resolve/main/llama3-8b-cpt-sahabatai-v1-instruct.Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.54 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3-8b-sahabat-ai/resolve/main/llama3-8b-cpt-sahabatai-v1-instruct.Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.92 } } } \ No newline at end of file diff --git a/models/llama-3.1-8b.json b/models/llama-3.1-8b.json index 146f8c4..7a00435 100644 --- a/models/llama-3.1-8b.json +++ b/models/llama-3.1-8b.json @@ -1,6 +1,10 @@ { "name": "Llama 3.1 8B", "author": "Meta", + "hidden_size": 4096, + "attention_heads": 32, + "hidden_layers": 32, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.1-8b/resolve/main/Meta-Llama-3.1-8B-Instruct.f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 16.1 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.1-8b/resolve/main/Meta-Llama-3.1-8B-Instruct.Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.54 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": 
"https://huggingface.co/kolosal/llama-3.1-8b/resolve/main/Meta-Llama-3.1-8B-Instruct.Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.92 } } } \ No newline at end of file diff --git a/models/llama-3.2-1b.json b/models/llama-3.2-1b.json index 781d0ad..714b17e 100644 --- a/models/llama-3.2-1b.json +++ b/models/llama-3.2-1b.json @@ -1,6 +1,10 @@ { "name": "Llama 3.2 1B", "author": "Meta", + "hidden_size": 2048, + "attention_heads": 32, + "hidden_layers": 16, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-1b/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.48 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-1b/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.32 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-1b/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.88 } } } \ No newline at end of file diff --git a/models/llama-3.2-3b.json b/models/llama-3.2-3b.json index 345ca0f..ff39a64 100644 --- a/models/llama-3.2-3b.json +++ b/models/llama-3.2-3b.json @@ -1,6 +1,10 @@ { "name": "Llama 3.2 3B", "author": "Meta", + "hidden_size": 3072, + "attention_heads": 24, + "hidden_layers": 28, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-3b/resolve/main/Llama-3.2-3B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + 
"lastSelected": 0, + "size": 6.43 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-3b/resolve/main/Llama-3.2-3B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.42 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/llama-3.2-3b/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.02 } } } \ No newline at end of file diff --git a/models/phi-4-14b.json b/models/phi-4-14b.json index c819573..35b0fd5 100644 --- a/models/phi-4-14b.json +++ b/models/phi-4-14b.json @@ -1,6 +1,10 @@ { "name": "Phi 4 14B", "author": "Microsoft", + "hidden_size": 5120, + "attention_heads": 40, + "hidden_layers": 40, + "kv_heads": 10, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4/resolve/main/phi-4-F16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 29.1 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4/resolve/main/phi-4-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.6 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4/resolve/main/phi-4-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.89 } } } \ No newline at end of file diff --git a/models/phi-4-mini-3.8b.json b/models/phi-4-mini-3.8b.json index 0c08fa6..b20ef5e 100644 --- a/models/phi-4-mini-3.8b.json +++ b/models/phi-4-mini-3.8b.json @@ -1,6 +1,10 @@ { "name": "Phi 4 Mini 3.8B", "author": "Microsoft", + 
"hidden_size": 3072, + "attention_heads": 24, + "hidden_layers": 32, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4-mini/resolve/main/Phi-4-mini-instruct.BF16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 7.68 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4-mini/resolve/main/Phi-4-mini-instruct.Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.08 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/phi-4-mini/resolve/main/Phi-4-mini-instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 2.49 } } } \ No newline at end of file diff --git a/models/qwen2.5-0.5b.json b/models/qwen2.5-0.5b.json index 4648131..32858cd 100644 --- a/models/qwen2.5-0.5b.json +++ b/models/qwen2.5-0.5b.json @@ -1,6 +1,10 @@ { "name": "Qwen2.5 0.5B", "author": "Alibaba", + "hidden_size": 896, + "attention_heads": 14, + "hidden_layers": 24, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-0.5b/resolve/main/Qwen2.5-0.5B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.994 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-0.5b/resolve/main/Qwen2.5-0.5B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.531 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": 
"https://huggingface.co/kolosal/qwen2.5-0.5b/resolve/main/Qwen2.5-0.5B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.398 } } } \ No newline at end of file diff --git a/models/qwen2.5-1.5b.json b/models/qwen2.5-1.5b.json index d75c773..2e459ba 100644 --- a/models/qwen2.5-1.5b.json +++ b/models/qwen2.5-1.5b.json @@ -1,6 +1,10 @@ { "name": "Qwen2.5 1.5B", "author": "Alibaba", + "hidden_size": 1536, + "attention_heads": 12, + "hidden_layers": 28, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-1.5b/resolve/main/Qwen2.5-1.5B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.09 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-1.5b/resolve/main/Qwen2.5-1.5B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.65 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-1.5b/resolve/main/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.986 } } } \ No newline at end of file diff --git a/models/qwen2.5-14b.json b/models/qwen2.5-14b.json index 6aebc86..f574f7c 100644 --- a/models/qwen2.5-14b.json +++ b/models/qwen2.5-14b.json @@ -1,6 +1,10 @@ { "name": "Qwen2.5 14B", "author": "Alibaba", + "hidden_size": 5120, + "attention_heads": 40, + "hidden_layers": 48, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/bartowski/Qwen2.5-14B-Instruct-GGUF/resolve/main/Qwen2.5-14B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 
+ "lastSelected": 0, + "size": 29.5 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/bartowski/Qwen2.5-14B-Instruct-GGUF/resolve/main/Qwen2.5-14B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.7 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/bartowski/Qwen2.5-14B-Instruct-GGUF/resolve/main/Qwen2.5-14B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.99 } } } \ No newline at end of file diff --git a/models/qwen2.5-3b.json b/models/qwen2.5-3b.json index c15c311..1ef0375 100644 --- a/models/qwen2.5-3b.json +++ b/models/qwen2.5-3b.json @@ -1,6 +1,10 @@ { "name": "Qwen2.5 3B", "author": "Alibaba", + "hidden_size": 2048, + "attention_heads": 16, + "hidden_layers": 36, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-3b/resolve/main/Qwen2.5-3B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 6.18 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-3b/resolve/main/Qwen2.5-3B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.29 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-3b/resolve/main/Qwen2.5-3B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.93 } } } \ No newline at end of file diff --git a/models/qwen2.5-7b.json b/models/qwen2.5-7b.json index 3cd26ea..05a52f4 100644 --- a/models/qwen2.5-7b.json +++ b/models/qwen2.5-7b.json @@ -1,6 +1,10 
@@ { "name": "Qwen2.5 7B", "author": "Alibaba", + "hidden_size": 3584, + "attention_heads": 28, + "hidden_layers": 28, + "kv_heads": 4, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,7 +12,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-7b/resolve/main/Qwen2.5-7B-Instruct-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.2 }, "8-bit Quantized": { "type": "8-bit Quantized", @@ -16,7 +21,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-7b/resolve/main/Qwen2.5-7B-Instruct-Q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.1 }, "4-bit Quantized": { "type": "4-bit Quantized", @@ -24,7 +30,8 @@ "downloadLink": "https://huggingface.co/kolosal/qwen2.5-7b/resolve/main/Qwen2.5-7B-Instruct-Q4_K_M.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 4.68 } } } \ No newline at end of file diff --git a/models/qwen2.5-coder-0.5b.json b/models/qwen2.5-coder-0.5b.json index 49fb086..e8221e9 100644 --- a/models/qwen2.5-coder-0.5b.json +++ b/models/qwen2.5-coder-0.5b.json @@ -1,30 +1,37 @@ { "name": "Qwen Coder 0.5B", "author": "Alibaba", + "hidden_size": 896, + "attention_heads": 14, + "hidden_layers": 24, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", "path": "models/qwen-coder-0.5b/fp16/qwen2.5-coder-0.5b-instruct-fp16.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-0.5b-instruct-fp16.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-0.5b/resolve/main/qwen2.5-coder-0.5b-instruct-fp16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.994 }, "8-bit Quantized": { "type": "8-bit Quantized", "path": "models/qwen-coder-0.5b/int8/qwen2.5-coder-0.5b-instruct-q8_0.gguf", - "downloadLink": 
"https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-0.5b-instruct-q8_0.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-0.5b/resolve/main/qwen2.5-coder-0.5b-instruct-q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.531 }, "4-bit Quantized": { "type": "4-bit Quantized", "path": "models/qwen-coder-0.5b/int4/qwen2.5-coder-0.5b-instruct-q4_k_m.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-0.5b-instruct-q4_k_m.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-0.5b/resolve/main/qwen2.5-coder-0.5b-instruct-q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.398 } } } \ No newline at end of file diff --git a/models/qwen2.5-coder-1.5b.json b/models/qwen2.5-coder-1.5b.json index 85c8c4f..f40720a 100644 --- a/models/qwen2.5-coder-1.5b.json +++ b/models/qwen2.5-coder-1.5b.json @@ -1,6 +1,10 @@ { "name": "Qwen Coder 1.5B", "author": "Alibaba", + "hidden_size": 1536, + "attention_heads": 12, + "hidden_layers": 28, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", @@ -8,23 +12,26 @@ "downloadLink": "https://huggingface.co/neopolita/qwen2.5-coder-1.5b-gguf/resolve/main/ggml-model-f16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.09 }, "8-bit Quantized": { "type": "8-bit Quantized", "path": "models/qwen-coder-1.5b/int8/qwen2.5-coder-0.5b-instruct-q8_0.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-1.5b-instruct-q8_0.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-1.5b/resolve/main/qwen2.5-coder-1.5b-instruct-q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.65 }, "4-bit 
Quantized": { "type": "4-bit Quantized", "path": "models/qwen-coder-1.5b/int4/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-1.5b/resolve/main/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 0.986 } } } \ No newline at end of file diff --git a/models/qwen2.5-coder-14b.json b/models/qwen2.5-coder-14b.json index 890d9dd..8d91445 100644 --- a/models/qwen2.5-coder-14b.json +++ b/models/qwen2.5-coder-14b.json @@ -1,30 +1,37 @@ { "name": "Qwen Coder 14B", "author": "Alibaba", + "hidden_size": 5120, + "attention_heads": 40, + "hidden_layers": 48, + "kv_heads": 8, "variants": { "Full Precision": { "type": "Full Precision", "path": "models/qwen-coder-14b/fp16/qwen2.5-coder-14b-instruct-fp16.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-14B-Instruct-GGUF/resolve/main/qwen2.5-coder-14b-instruct-fp16.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-14b/resolve/main/qwen2.5-coder-14b-instruct-fp16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 29.5 }, "8-bit Quantized": { "type": "8-bit Quantized", "path": "models/qwen-coder-14b/int8/qwen2.5-coder-14b-instruct-q8_0.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-14B-Instruct-GGUF/resolve/main/qwen2.5-coder-14b-instruct-q8_0.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-14b/resolve/main/qwen2.5-coder-14b-instruct-q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.7 }, "4-bit Quantized": { "type": "4-bit Quantized", "path": "models/qwen-coder-14b/int4/qwen2.5-coder-14b-instruct-q4_k_m.gguf", - "downloadLink": 
"https://huggingface.co/Qwen/Qwen2.5-Coder-14B-Instruct-GGUF/resolve/main/qwen2.5-coder-14b-instruct-q4_k_m.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-14b/resolve/main/qwen2.5-coder-14b-instruct-q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.99 } } } \ No newline at end of file diff --git a/models/qwen2.5-coder-3b.json b/models/qwen2.5-coder-3b.json index 15f82be..8078526 100644 --- a/models/qwen2.5-coder-3b.json +++ b/models/qwen2.5-coder-3b.json @@ -1,30 +1,37 @@ { "name": "Qwen Coder 3B", "author": "Alibaba", + "hidden_size": 2048, + "attention_heads": 16, + "hidden_layers": 36, + "kv_heads": 2, "variants": { "Full Precision": { "type": "Full Precision", "path": "models/qwen-coder-3b/fp16/qwen2.5-coder-3b-instruct-fp16.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct-GGUF/resolve/main/qwen2.5-coder-3b-instruct-fp16.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-3b/resolve/main/qwen2.5-coder-3b-instruct-fp16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 6.18 }, "8-bit Quantized": { "type": "8-bit Quantized", "path": "models/qwen-coder-3b/int8/qwen2.5-coder-3b-instruct-q8_0.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct-GGUF/resolve/main/qwen2.5-coder-3b-instruct-q8_0.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-3b/resolve/main/qwen2.5-coder-3b-instruct-q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 3.29 }, "4-bit Quantized": { "type": "4-bit Quantized", "path": "models/qwen-coder-3b/int4/qwen2.5-coder-3b-instruct-q4_k_m.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct-GGUF/resolve/main/qwen2.5-coder-3b-instruct-q4_k_m.gguf", + "downloadLink": 
"https://huggingface.co/kolosal/qwen2.5-coder-3b/resolve/main/qwen2.5-coder-3b-instruct-q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 1.93 } } } \ No newline at end of file diff --git a/models/qwen2.5-coder-7b.json b/models/qwen2.5-coder-7b.json index c98a7d5..4094363 100644 --- a/models/qwen2.5-coder-7b.json +++ b/models/qwen2.5-coder-7b.json @@ -1,30 +1,37 @@ { "name": "Qwen Coder 7B", "author": "Alibaba", + "hidden_size": 3584, + "attention_heads": 28, + "hidden_layers": 28, + "kv_heads": 4, "variants": { "Full Precision": { "type": "Full Precision", "path": "models/qwen-coder-7b/fp16/qwen2.5-coder-7b-instruct-fp16.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct-GGUF/resolve/main/qwen2.5-coder-7b-instruct-fp16.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-7b/resolve/main/qwen2.5-coder-7b-instruct-fp16.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 15.2 }, "8-bit Quantized": { "type": "8-bit Quantized", "path": "models/qwen-coder-7b/int8/qwen2.5-coder-3b-instruct-q8_0.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct-GGUF/resolve/main/qwen2.5-coder-3b-instruct-q8_0.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-7b/resolve/main/qwen2.5-coder-7b-instruct-q8_0.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + "size": 8.1 }, "4-bit Quantized": { "type": "4-bit Quantized", "path": "models/qwen-coder-7b/int4/qwen2.5-coder-7b-instruct-q4_k_m.gguf", - "downloadLink": "https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct-GGUF/resolve/main/qwen2.5-coder-7b-instruct-q4_k_m.gguf", + "downloadLink": "https://huggingface.co/kolosal/qwen2.5-coder-7b/resolve/main/qwen2.5-coder-7b-instruct-q4_k_m.gguf", "isDownloaded": false, "downloadProgress": 0.0, - "lastSelected": 0 + "lastSelected": 0, + 
"size": 4.68 } } } \ No newline at end of file diff --git a/server-test/python/openai_completion.py b/server-test/python/openai_completion.py new file mode 100644 index 0000000..e94b73b --- /dev/null +++ b/server-test/python/openai_completion.py @@ -0,0 +1,31 @@ +import openai +import os + +# Configure the client to use your local endpoint +client = openai.OpenAI( + base_url="http://localhost:8080/v1", + api_key="sk-dummy" # Using dummy API key as in the curl example +) + +print("Starting streaming request...\n") + +prompt = f"Halo adalah" + +# Make a streaming request using completions API instead of chat +stream = client.completions.create( + model="Qwen2.5 0.5B", + prompt=prompt, + stream=True, + max_tokens=32 +) + +# Process streaming response +print("Streaming response:") +full_response = "" +for chunk in stream: + if chunk.choices[0].text is not None: + content = chunk.choices[0].text + full_response += content + print(content, end="", flush=True) + +print("\n\nFull response:", full_response) diff --git a/server-test/python/openai_completion_ns.py b/server-test/python/openai_completion_ns.py new file mode 100644 index 0000000..5f65882 --- /dev/null +++ b/server-test/python/openai_completion_ns.py @@ -0,0 +1,27 @@ +import openai +import os + +# Configure the client to use your local endpoint +client = openai.OpenAI( + base_url="http://localhost:8080/v1", + api_key="sk-dummy" # Using dummy API key as in the curl example +) + +print("Starting non-streaming request...\n") + +# Format the messages into a single text prompt +system_message = "You are a helpful assistant." +user_message = "Why anything to the power of zero is 1?" 
+prompt = f"{system_message}\n\nUser: {user_message}\nAssistant:" + +# Make a non-streaming request using completions API +response = client.completions.create( + model="Qwen2.5 0.5B", + prompt=prompt, + max_tokens=32 +) + +# Process the response +full_response = response.choices[0].text +print("Response:") +print(full_response) diff --git a/server-test/python/openai_test.py b/server-test/python/openai_test.py index 1d4fcd4..73868d7 100644 --- a/server-test/python/openai_test.py +++ b/server-test/python/openai_test.py @@ -11,7 +11,7 @@ # Make a streaming request stream = client.chat.completions.create( - model="claude-3-opus-20240229", + model="Qwen2.5 0.5B", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Why anything to the power of zero is 1?"} diff --git a/server-test/python/openai_test_2.py b/server-test/python/openai_test_2.py index 40554f2..faef216 100644 --- a/server-test/python/openai_test_2.py +++ b/server-test/python/openai_test_2.py @@ -11,7 +11,7 @@ # Make a streaming request stream = client.chat.completions.create( - model="claude-3-opus-20240229", + model="Qwen Coder 0.5B", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} diff --git a/source/main.cpp b/source/main.cpp index c0fc4ed..677f3ca 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -6,6 +6,7 @@ #include "ui/fonts.hpp" #include "ui/title_bar.hpp" #include "ui/tab_manager.hpp" +#include "ui/status_bar.hpp" #include "chat/chat_manager.hpp" #include "model/preset_manager.hpp" @@ -149,6 +150,9 @@ class Application tabManager->addTab(std::make_unique()); tabManager->addTab(std::make_unique()); + // Initialize the status bar + statusBar = std::make_unique(); + // Create and show the window window = WindowFactory::createWindow(); window->createWindow(Config::WINDOW_WIDTH, Config::WINDOW_HEIGHT, Config::WINDOW_TITLE, @@ -201,6 +205,9 @@ class Application // Render the currently 
active tab (chat tab in this example) tabManager->renderCurrentTab(); + // Render the status bar + statusBar->render(); + // Render ImGui ImGui::Render(); @@ -237,6 +244,7 @@ class Application std::unique_ptr cleanup; std::unique_ptr transitionManager; std::unique_ptr tabManager; + std::unique_ptr statusBar; int display_w; int display_h; };