diff --git a/common/download.cpp b/common/download.cpp index 099eaa059b..6a23c7d071 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -469,36 +471,79 @@ std::pair> common_remote_get_content(const std::string & #elif defined(LLAMA_USE_HTTPLIB) -static bool is_output_a_tty() { +class ProgressBar { + static inline std::mutex mutex; + static inline std::map lines; + static inline int max_line = 0; + + static void cleanup(const ProgressBar * line) { + lines.erase(line); + if (lines.empty()) { + max_line = 0; + } + } + + static bool is_output_a_tty() { #if defined(_WIN32) - return _isatty(_fileno(stdout)); + return _isatty(_fileno(stdout)); #else - return isatty(1); + return isatty(1); #endif -} + } -static void print_progress(size_t current, size_t total) { - if (!is_output_a_tty()) { - return; +public: + ProgressBar() = default; + + ~ProgressBar() { + std::lock_guard lock(mutex); + cleanup(this); } - if (!total) { - return; + void update(size_t current, size_t total) { + if (!is_output_a_tty()) { + return; + } + + if (!total) { + return; + } + + std::lock_guard lock(mutex); + + if (lines.find(this) == lines.end()) { + lines[this] = max_line++; + std::cout << "\n"; + } + int lines_up = max_line - lines[this]; + + size_t width = 50; + size_t pct = (100 * current) / total; + size_t pos = (width * current) / total; + + std::cout << "\033[s"; + + if (lines_up > 0) { + std::cout << "\033[" << lines_up << "A"; + } + std::cout << "\033[2K\r[" + << std::string(pos, '=') + << (pos < width ? ">" : "") + << std::string(width - pos, ' ') + << "] " << std::setw(3) << pct << "% (" + << current / (1024 * 1024) << " MB / " + << total / (1024 * 1024) << " MB) " + << "\033[u"; + + std::cout.flush(); + + if (current == total) { + cleanup(this); + } } - size_t width = 50; - size_t pct = (100 * current) / total; - size_t pos = (width * current) / total; - - std::cout << "[" - << std::string(pos, '=') - << (pos < width ? ">" : "") - << std::string(width - pos, ' ') - << "] " << std::setw(3) << pct << "% (" - << current / (1024 * 1024) << " MB / " - << total / (1024 * 1024) << " MB)\r"; - std::cout.flush(); -} + ProgressBar(const ProgressBar &) = delete; + ProgressBar & operator=(const ProgressBar &) = delete; +}; static bool common_pull_file(httplib::Client & cli, const std::string & resolve_path, @@ -520,6 +565,7 @@ static bool common_pull_file(httplib::Client & cli, const char * func = __func__; // avoid __func__ inside a lambda size_t downloaded = existing_size; size_t progress_step = 0; + ProgressBar bar; auto res = cli.Get(resolve_path, headers, [&](const httplib::Response &response) { @@ -551,7 +597,7 @@ static bool common_pull_file(httplib::Client & cli, progress_step += len; if (progress_step >= total_size / 1000 || downloaded == total_size) { - print_progress(downloaded, total_size); + bar.update(downloaded, total_size); progress_step = 0; } return true; @@ -559,8 +605,6 @@ static bool common_pull_file(httplib::Client & cli, nullptr ); - std::cout << "\n"; - if (!res) { LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1); return false;