Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 69 additions & 25 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <filesystem>
#include <fstream>
#include <future>
#include <map>
#include <mutex>
#include <regex>
#include <string>
#include <thread>
Expand Down Expand Up @@ -469,36 +471,79 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

#elif defined(LLAMA_USE_HTTPLIB)

static bool is_output_a_tty() {
class ProgressBar {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
static inline int max_line = 0;

static void cleanup(const ProgressBar * line) {
lines.erase(line);
if (lines.empty()) {
max_line = 0;
}
}

static bool is_output_a_tty() {
#if defined(_WIN32)
return _isatty(_fileno(stdout));
return _isatty(_fileno(stdout));
#else
return isatty(1);
return isatty(1);
#endif
}
}

static void print_progress(size_t current, size_t total) {
if (!is_output_a_tty()) {
return;
public:
ProgressBar() = default;

~ProgressBar() {
std::lock_guard<std::mutex> lock(mutex);
cleanup(this);
}

if (!total) {
return;
void update(size_t current, size_t total) {
if (!is_output_a_tty()) {
return;
}

if (!total) {
return;
}

std::lock_guard<std::mutex> lock(mutex);

if (lines.find(this) == lines.end()) {
lines[this] = max_line++;
std::cout << "\n";
}
int lines_up = max_line - lines[this];

size_t width = 50;
size_t pct = (100 * current) / total;
size_t pos = (width * current) / total;

std::cout << "\033[s";

if (lines_up > 0) {
std::cout << "\033[" << lines_up << "A";
}
std::cout << "\033[2K\r["
<< std::string(pos, '=')
<< (pos < width ? ">" : "")
<< std::string(width - pos, ' ')
<< "] " << std::setw(3) << pct << "% ("
<< current / (1024 * 1024) << " MB / "
<< total / (1024 * 1024) << " MB) "
<< "\033[u";

std::cout.flush();

if (current == total) {
cleanup(this);
}
}

size_t width = 50;
size_t pct = (100 * current) / total;
size_t pos = (width * current) / total;

std::cout << "["
<< std::string(pos, '=')
<< (pos < width ? ">" : "")
<< std::string(width - pos, ' ')
<< "] " << std::setw(3) << pct << "% ("
<< current / (1024 * 1024) << " MB / "
<< total / (1024 * 1024) << " MB)\r";
std::cout.flush();
}
ProgressBar(const ProgressBar &) = delete;
ProgressBar & operator=(const ProgressBar &) = delete;
};

static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path,
Expand All @@ -520,6 +565,7 @@ static bool common_pull_file(httplib::Client & cli,
const char * func = __func__; // avoid __func__ inside a lambda
size_t downloaded = existing_size;
size_t progress_step = 0;
ProgressBar bar;

auto res = cli.Get(resolve_path, headers,
[&](const httplib::Response &response) {
Expand Down Expand Up @@ -551,16 +597,14 @@ static bool common_pull_file(httplib::Client & cli,
progress_step += len;

if (progress_step >= total_size / 1000 || downloaded == total_size) {
print_progress(downloaded, total_size);
bar.update(downloaded, total_size);
progress_step = 0;
}
return true;
},
nullptr
);

std::cout << "\n";

if (!res) {
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
return false;
Expand Down
Loading