Skip to content

Commit 09717b6

Browse files
committed
common : add minimalist multi-thread progress bar
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
1 parent 7f8ef50 commit 09717b6

File tree

1 file changed

+71
-25
lines changed

1 file changed

+71
-25
lines changed

common/download.cpp

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <filesystem>
1313
#include <fstream>
1414
#include <future>
15+
#include <map>
16+
#include <mutex>
1517
#include <regex>
1618
#include <string>
1719
#include <thread>
@@ -469,36 +471,81 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
469471

470472
#elif defined(LLAMA_USE_HTTPLIB)
471473

472-
static bool is_output_a_tty() {
474+
class ProgressBar {
475+
static inline std::mutex mutex;
476+
static inline std::map<const ProgressBar *, int> lines;
477+
static inline int max_line = 0;
478+
479+
static void cleanup(const ProgressBar * line) {
480+
lines.erase(line);
481+
if (lines.empty()) {
482+
max_line = 0;
483+
}
484+
}
485+
486+
static bool is_output_a_tty() {
473487
#if defined(_WIN32)
474-
return _isatty(_fileno(stdout));
488+
return _isatty(_fileno(stdout));
475489
#else
476-
return isatty(1);
490+
return isatty(1);
477491
#endif
478-
}
492+
}
479493

480-
static void print_progress(size_t current, size_t total) {
481-
if (!is_output_a_tty()) {
482-
return;
494+
public:
495+
ProgressBar() = default;
496+
497+
~ProgressBar() {
498+
std::lock_guard<std::mutex> lock(mutex);
499+
cleanup(this);
483500
}
484501

485-
if (!total) {
486-
return;
502+
void update(size_t current, size_t total) {
503+
(void)this; // avoid the static warning
504+
505+
if (!is_output_a_tty()) {
506+
return;
507+
}
508+
509+
if (!total) {
510+
return;
511+
}
512+
513+
std::lock_guard<std::mutex> lock(mutex);
514+
515+
if (lines.find(this) == lines.end()) {
516+
lines[this] = max_line++;
517+
std::cout << "\n";
518+
}
519+
int lines_up = max_line - lines[this];
520+
521+
size_t width = 50;
522+
size_t pct = (100 * current) / total;
523+
size_t pos = (width * current) / total;
524+
525+
std::cout << "\033[s";
526+
527+
if (lines_up > 0) {
528+
std::cout << "\033[" << lines_up << "A";
529+
}
530+
std::cout << "\033[2K\r["
531+
<< std::string(pos, '=')
532+
<< (pos < width ? ">" : "")
533+
<< std::string(width - pos, ' ')
534+
<< "] " << std::setw(3) << pct << "% ("
535+
<< current / (1024 * 1024) << " MB / "
536+
<< total / (1024 * 1024) << " MB) "
537+
<< "\033[u";
538+
539+
std::cout.flush();
540+
541+
if (current == total) {
542+
cleanup(this);
543+
}
487544
}
488545

489-
size_t width = 50;
490-
size_t pct = (100 * current) / total;
491-
size_t pos = (width * current) / total;
492-
493-
std::cout << "["
494-
<< std::string(pos, '=')
495-
<< (pos < width ? ">" : "")
496-
<< std::string(width - pos, ' ')
497-
<< "] " << std::setw(3) << pct << "% ("
498-
<< current / (1024 * 1024) << " MB / "
499-
<< total / (1024 * 1024) << " MB)\r";
500-
std::cout.flush();
501-
}
546+
ProgressBar(const ProgressBar &) = delete;
547+
ProgressBar & operator=(const ProgressBar &) = delete;
548+
};
502549

503550
static bool common_pull_file(httplib::Client & cli,
504551
const std::string & resolve_path,
@@ -520,6 +567,7 @@ static bool common_pull_file(httplib::Client & cli,
520567
const char * func = __func__; // avoid __func__ inside a lambda
521568
size_t downloaded = existing_size;
522569
size_t progress_step = 0;
570+
ProgressBar bar;
523571

524572
auto res = cli.Get(resolve_path, headers,
525573
[&](const httplib::Response &response) {
@@ -551,16 +599,14 @@ static bool common_pull_file(httplib::Client & cli,
551599
progress_step += len;
552600

553601
if (progress_step >= total_size / 1000 || downloaded == total_size) {
554-
print_progress(downloaded, total_size);
602+
bar.update(downloaded, total_size);
555603
progress_step = 0;
556604
}
557605
return true;
558606
},
559607
nullptr
560608
);
561609

562-
std::cout << "\n";
563-
564610
if (!res) {
565611
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
566612
return false;

0 commit comments

Comments
 (0)