Skip to content

Commit 2935963

Browse files
committed
Add Go OCI library integration with go-containerregistry
So we can pull from any OCI registry, add authentication, etc. Add docker-style progress bars and resumable downloads to OCI pulls Update documentation with progress bars and resumable downloads info Make OCI Go build optional and skip editorconfig for oci-go Add Go version check before building OCI library Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent ee09828 commit 2935963

File tree

11 files changed

+880
-101
lines changed

11 files changed

+880
-101
lines changed

.ecrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"Exclude": ["^\\.gitmodules$", "stb_image\\.h"],
2+
"Exclude": ["^\\.gitmodules$", "stb_image\\.h", "oci-go/"],
33
"Disable": {
44
"IndentSize": true
55
}

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
*.swp
2222
*.tmp
2323

24+
# OCI Go generated files
25+
oci-go/liboci.h
26+
2427
# IDE / OS
2528

2629
.cache/

common/CMakeLists.txt

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,75 @@ if (BUILD_SHARED_LIBS)
7777
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
7878
endif()
7979

80-
set(LLAMA_COMMON_EXTRA_LIBS build_info)
80+
# Build OCI Go library
81+
find_program(GO_EXECUTABLE go)
82+
if (GO_EXECUTABLE)
83+
# Check Go version - we need at least 1.21 for toolchain directive support
84+
set(GO_VERSION_RESULT 1)
85+
execute_process(
86+
COMMAND ${GO_EXECUTABLE} version
87+
OUTPUT_VARIABLE GO_VERSION_OUTPUT
88+
OUTPUT_STRIP_TRAILING_WHITESPACE
89+
RESULT_VARIABLE GO_VERSION_RESULT
90+
)
91+
92+
if (GO_VERSION_RESULT EQUAL 0)
93+
# Extract version number from "go version go1.X.Y ..." or "go version go1.X ..."
94+
string(REGEX MATCH "go([0-9]+)\\.([0-9]+)" GO_VERSION_MATCH "${GO_VERSION_OUTPUT}")
95+
if (GO_VERSION_MATCH)
96+
set(GO_VERSION_MAJOR ${CMAKE_MATCH_1})
97+
set(GO_VERSION_MINOR ${CMAKE_MATCH_2})
98+
99+
if (GO_VERSION_MAJOR LESS 1 OR (GO_VERSION_MAJOR EQUAL 1 AND GO_VERSION_MINOR LESS 21))
100+
message(WARNING "Go version ${GO_VERSION_MAJOR}.${GO_VERSION_MINOR} is too old. OCI functionality requires Go 1.21 or later. OCI functionality will not be available.")
101+
set(GO_VERSION_OK FALSE)
102+
else()
103+
set(GO_VERSION_OK TRUE)
104+
endif()
105+
else()
106+
message(WARNING "Unable to parse Go version from: ${GO_VERSION_OUTPUT}. OCI functionality will not be available.")
107+
set(GO_VERSION_OK FALSE)
108+
endif()
109+
else()
110+
message(WARNING "Failed to get Go version. OCI functionality will not be available.")
111+
set(GO_VERSION_OK FALSE)
112+
endif()
113+
endif()
114+
115+
if (GO_EXECUTABLE AND GO_VERSION_OK)
116+
set(OCI_GO_DIR ${CMAKE_SOURCE_DIR}/oci-go)
117+
set(OCI_LIB ${OCI_GO_DIR}/liboci.a)
118+
set(OCI_HEADER ${OCI_GO_DIR}/liboci.h)
119+
120+
add_custom_command(
121+
OUTPUT ${OCI_LIB} ${OCI_HEADER}
122+
COMMAND ${GO_EXECUTABLE} build -buildmode=c-archive -o ${OCI_LIB} ${OCI_GO_DIR}/oci.go
123+
WORKING_DIRECTORY ${OCI_GO_DIR}
124+
DEPENDS ${OCI_GO_DIR}/oci.go ${OCI_GO_DIR}/go.mod
125+
COMMENT "Building OCI Go library"
126+
)
127+
128+
add_custom_target(oci_go_lib DEPENDS ${OCI_LIB} ${OCI_HEADER})
129+
add_dependencies(${TARGET} oci_go_lib)
130+
131+
target_include_directories(${TARGET} PRIVATE ${OCI_GO_DIR})
132+
target_sources(${TARGET} PRIVATE oci.cpp oci.h)
133+
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_OCI)
134+
set(LLAMA_COMMON_EXTRA_LIBS build_info ${OCI_LIB})
135+
136+
# On macOS, the Go runtime requires CoreFoundation and Security frameworks
137+
if (APPLE)
138+
find_library(OCI_CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
139+
find_library(OCI_SECURITY_FRAMEWORK Security REQUIRED)
140+
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${OCI_CORE_FOUNDATION_FRAMEWORK} ${OCI_SECURITY_FRAMEWORK})
141+
endif()
142+
else()
143+
if (NOT GO_EXECUTABLE)
144+
message(WARNING "Go compiler not found. OCI functionality will not be available.")
145+
endif()
146+
set(LLAMA_COMMON_EXTRA_LIBS build_info)
147+
endif()
148+
81149

82150
# Use curl to download model url
83151
if (LLAMA_CURL)

common/arg.cpp

Lines changed: 29 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#include "gguf.h" // for reading GGUF splits
66
#include "json-schema-to-grammar.h"
77
#include "log.h"
8+
#ifdef LLAMA_USE_OCI
9+
#include "oci.h"
10+
#endif
811
#include "sampling.h"
912

1013
// fix problem with std::min and std::max
@@ -1043,119 +1046,42 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
10431046
// Docker registry functions
10441047
//
10451048

1046-
static std::string common_docker_get_token(const std::string & repo) {
1047-
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
1048-
1049-
common_remote_params params;
1050-
auto res = common_remote_get_content(url, params);
1051-
1052-
if (res.first != 200) {
1053-
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
1054-
}
1055-
1056-
std::string response_str(res.second.begin(), res.second.end());
1057-
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
1058-
1059-
if (!response.contains("token")) {
1060-
throw std::runtime_error("Docker registry token response missing 'token' field");
1061-
}
1062-
1063-
return response["token"].get<std::string>();
1064-
}
1065-
1049+
#ifdef LLAMA_USE_OCI
10661050
static std::string common_docker_resolve_model(const std::string & docker) {
1067-
// Parse ai/smollm2:135M-Q4_0
1068-
size_t colon_pos = docker.find(':');
1069-
std::string repo, tag;
1070-
if (colon_pos != std::string::npos) {
1071-
repo = docker.substr(0, colon_pos);
1072-
tag = docker.substr(colon_pos + 1);
1073-
} else {
1074-
repo = docker;
1075-
tag = "latest";
1076-
}
1051+
// Parse image reference (e.g., ai/smollm2:135M-Q4_0)
1052+
std::string image_ref = docker;
10771053

1078-
// ai/ is the default
1079-
size_t slash_pos = docker.find('/');
1054+
// ai/ is the default namespace for Docker Hub
1055+
size_t slash_pos = docker.find('/');
10801056
if (slash_pos == std::string::npos) {
1081-
repo.insert(0, "ai/");
1057+
image_ref = "ai/" + docker;
10821058
}
10831059

1084-
LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
1085-
try {
1086-
// --- helper: digest validation ---
1087-
auto validate_oci_digest = [](const std::string & digest) -> std::string {
1088-
// Expected: algo:hex ; start with sha256 (64 hex chars)
1089-
// You can extend this map if supporting other algorithms in future.
1090-
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
1091-
std::smatch m;
1092-
if (!std::regex_match(digest, m, re)) {
1093-
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
1094-
}
1095-
// normalize hex to lowercase
1096-
std::string normalized = digest;
1097-
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
1098-
return std::tolower(c);
1099-
});
1100-
return normalized;
1101-
};
1102-
1103-
std::string token = common_docker_get_token(repo); // Get authentication token
1104-
1105-
// Get manifest
1106-
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
1107-
std::string manifest_url = url_prefix + "/manifests/" + tag;
1108-
common_remote_params manifest_params;
1109-
manifest_params.headers.push_back("Authorization: Bearer " + token);
1110-
manifest_params.headers.push_back(
1111-
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
1112-
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
1113-
if (manifest_res.first != 200) {
1114-
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
1115-
}
1116-
1117-
std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
1118-
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
1119-
std::string gguf_digest; // Find the GGUF layer
1120-
if (manifest.contains("layers")) {
1121-
for (const auto & layer : manifest["layers"]) {
1122-
if (layer.contains("mediaType")) {
1123-
std::string media_type = layer["mediaType"].get<std::string>();
1124-
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
1125-
media_type.find("gguf") != std::string::npos) {
1126-
gguf_digest = layer["digest"].get<std::string>();
1127-
break;
1128-
}
1129-
}
1130-
}
1131-
}
1132-
1133-
if (gguf_digest.empty()) {
1134-
throw std::runtime_error("No GGUF layer found in Docker manifest");
1135-
}
1060+
// Add registry prefix if not present
1061+
if (image_ref.find("registry-1.docker.io/") != 0 && image_ref.find("docker.io/") != 0 &&
1062+
image_ref.find("index.docker.io/") != 0) {
1063+
// For Docker Hub images without explicit registry
1064+
image_ref = "index.docker.io/" + image_ref;
1065+
}
11361066

1137-
// Validate & normalize digest
1138-
gguf_digest = validate_oci_digest(gguf_digest);
1139-
LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
1067+
try {
1068+
// Get cache directory
1069+
std::string cache_dir = fs_get_cache_directory();
11401070

1141-
// Prepare local filename
1142-
std::string model_filename = repo;
1143-
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
1144-
model_filename += "_" + tag + ".gguf";
1145-
std::string local_path = fs_get_cache_file(model_filename);
1071+
// Call the Go OCI library
1072+
auto result = oci_pull_model(image_ref, cache_dir);
11461073

1147-
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
1148-
if (!common_download_file_single(blob_url, local_path, token, false)) {
1149-
throw std::runtime_error("Failed to download Docker Model");
1074+
if (!result.success()) {
1075+
throw std::runtime_error("OCI pull failed: " + result.error_message);
11501076
}
11511077

1152-
LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
1153-
return local_path;
1078+
return result.local_path;
11541079
} catch (const std::exception & e) {
1155-
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
1080+
LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
11561081
throw;
11571082
}
11581083
}
1084+
#endif // LLAMA_USE_OCI
11591085

11601086
//
11611087
// utils
@@ -1208,7 +1134,11 @@ static handle_model_result common_params_handle_model(
12081134
// handle pre-fill default model path and url based on hf_repo and hf_file
12091135
{
12101136
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
1137+
#ifdef LLAMA_USE_OCI
12111138
model.path = common_docker_resolve_model(model.docker_repo);
1139+
#else
1140+
LOG_ERR("Need to build with go compiler and LLAMA_USE_OCI\n");
1141+
#endif
12121142
} else if (!model.hf_repo.empty()) {
12131143
// short-hand to avoid specifying --hf-file -> default it to --model
12141144
if (model.hf_file.empty()) {

common/oci.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#ifdef LLAMA_USE_OCI
2+
3+
#include "oci.h"
4+
5+
#include "log.h"
6+
7+
#include <nlohmann/json.hpp>
8+
9+
// Include the Go-generated header
10+
#include "../oci-go/liboci.h"
11+
12+
using json = nlohmann::ordered_json;
13+
14+
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
15+
oci_pull_result result;
16+
result.error_code = 0;
17+
18+
// Call the Go function
19+
char * json_result = PullOCIModel(const_cast<char *>(imageRef.c_str()), const_cast<char *>(cacheDir.c_str()));
20+
21+
if (json_result == nullptr) {
22+
result.error_code = 1;
23+
result.error_message = "Failed to call OCI pull function";
24+
return result;
25+
}
26+
27+
try {
28+
// Parse the JSON result
29+
std::string json_str(json_result);
30+
auto j = json::parse(json_str);
31+
32+
if (j.contains("LocalPath")) {
33+
result.local_path = j["LocalPath"].get<std::string>();
34+
}
35+
if (j.contains("Digest")) {
36+
result.digest = j["Digest"].get<std::string>();
37+
}
38+
if (j.contains("Error") && !j["Error"].is_null()) {
39+
auto err = j["Error"];
40+
if (err.contains("Code")) {
41+
result.error_code = err["Code"].get<int>();
42+
}
43+
if (err.contains("Message")) {
44+
result.error_message = err["Message"].get<std::string>();
45+
}
46+
}
47+
} catch (const std::exception & e) {
48+
result.error_code = 1;
49+
result.error_message = std::string("Failed to parse result: ") + e.what();
50+
}
51+
52+
// Free the Go-allocated string
53+
FreeString(json_result);
54+
55+
return result;
56+
}
57+
58+
#endif // LLAMA_USE_OCI

common/oci.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#ifdef LLAMA_USE_OCI
4+
5+
#include <string>
6+
7+
// Structure to hold OCI pull results
8+
struct oci_pull_result {
9+
std::string local_path;
10+
std::string digest;
11+
int error_code;
12+
std::string error_message;
13+
14+
bool success() const { return error_code == 0; }
15+
};
16+
17+
// Pull a model from an OCI registry
18+
// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag")
19+
// cacheDir: directory to cache downloaded models
20+
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir);
21+
22+
#endif // LLAMA_USE_OCI

0 commit comments

Comments
 (0)