Skip to content

Commit 489c8d9

Browse files
committed
feature: aaron work on leanvec ood
1 parent 597112f commit 489c8d9

File tree

3 files changed

+87
-57
lines changed

3 files changed

+87
-57
lines changed

examples/cpp/shared/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ project(svs_shared_library_example
1919
)
2020

2121
# Other AVX versions can be found at https://github.com/intel/ScalableVectorSearch/releases.
22-
set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.0.9/svs-shared-library-0.0.9.tar.gz")
22+
set(SVS_URL "file:///raid0/dlin/my_work/private/libraries.ai.vector-search.svs-cpp-ood/lib.zip")
2323

2424
include(FetchContent)
2525
FetchContent_Declare(

examples/cpp/shared/example_vamana_with_compression.cpp

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,48 @@ int main() {
2828
// STEP 1: Compress Data with LeanVec, reducing dimensionality to leanvec_dim dimensions
2929
// and using 4 and 8 bits for primary and secondary levels respectively.
3030
//! [Compress data]
31-
const size_t num_threads = 4;
31+
const size_t num_threads = 64;
3232
size_t padding = 32;
33-
size_t leanvec_dim = 64;
33+
size_t leanvec_dim = 160;
3434
auto threadpool = svs::threads::as_threadpool(num_threads);
3535
auto loaded =
36-
svs::VectorDataLoader<float>(std::filesystem::path(SVS_DATA_DIR) / "data_f32.svs")
36+
//svs::VectorDataLoader<float, svs::Dynamic, svs::lib::Allocator<float>>("/export/data/ishwarsi/laion/laion_base_1M.fvecs")
37+
//svs::VectorDataLoader<float, svs::Dynamic, svs::lib::Allocator<float>>("/raid0/ishwarsi/datasets/open-images/oi_base_1M.fvecs")
38+
svs::VectorDataLoader<float, svs::Dynamic, svs::lib::Allocator<float>>("/raid0/ishwarsi/datasets/rqa/rqa_base_1M.fvecs")
3739
.load();
40+
auto learn_queries =
41+
svs::load_data<float, svs::Dynamic>("/raid0/ishwarsi/datasets/rqa/rqa_learn_query_10k_ood.fvecs");
42+
auto leanvec_matrix = svs::leanvec::compute_leanvec_matrices_ood<svs::Dynamic, svs::Dynamic>(loaded, learn_queries, svs::lib::MaybeStatic<svs::Dynamic>(leanvec_dim));
43+
//std::cerr << "data matrix: \n";
44+
//for(size_t i = 0; i < 100; ++i) {
45+
//for(size_t j = 0; j < leanvec_dim; ++j) {
46+
//std::cerr << *(leanvec_matrix.data_matrix_.data() + i * leanvec_dim + j) << "\n";
47+
//}
48+
//}
49+
//std::cerr << "query matrix: \n";
50+
//for(size_t i = 0; i < 100; ++i) {
51+
//for(size_t j = 0; j < leanvec_dim; ++j) {
52+
//std::cerr << *(leanvec_matrix.query_matrix_.data() + i * leanvec_dim + j) << "\n";
53+
//}
54+
//}
55+
//std::cerr << "data ood transform:\n";
56+
//for(size_t i = 0; i < 100; ++i) {
57+
//for(size_t k = 0; k < leanvec_dim; ++k) {
58+
//double tmp = 0.0;
59+
//for(size_t j = 0; j < loaded.dimensions(); ++j) {
60+
//tmp += (*(loaded.data() + i * loaded.dimensions() + j)) * (*(leanvec_matrix.data_matrix_.data() + j * leanvec_dim + k ));
61+
//}
62+
//std::cerr << tmp << "\n";
63+
//}
64+
//}
3865
auto data = svs::leanvec::LeanDataset<
3966
svs::leanvec::UsingLVQ<4>,
4067
svs::leanvec::UsingLVQ<8>,
4168
svs::Dynamic,
4269
svs::Dynamic>::
4370
reduce(
4471
loaded,
45-
std::nullopt,
72+
leanvec_matrix,
4673
threadpool,
4774
padding,
4875
svs::lib::MaybeStatic<svs::Dynamic>(leanvec_dim)
@@ -53,24 +80,27 @@ int main() {
5380
//! [Index Build]
5481
auto parameters = svs::index::vamana::VamanaBuildParameters{};
5582
svs::Vamana index = svs::Vamana::build<float>(
56-
parameters, data, svs::distance::DistanceL2(), num_threads
83+
parameters, data, svs::distance::DistanceIP(), num_threads
5784
);
85+
index.save("config", "graph", "data");
5886
//! [Index Build]
5987

6088
// STEP 3: Search the Index
6189
//! [Perform Queries]
62-
const size_t search_window_size = 50;
90+
const size_t search_window_size = 450;
6391
const size_t n_neighbors = 10;
6492
index.set_search_window_size(search_window_size);
6593

6694
auto queries =
67-
svs::load_data<float>(std::filesystem::path(SVS_DATA_DIR) / "queries_f32.fvecs");
95+
svs::load_data<float>("/raid0/ishwarsi/datasets/rqa/rqa_query_10k_ood.fvecs");
96+
//svs::load_data<float>("/raid0/ishwarsi/datasets/open-images/oi_queries_10k.fvecs");
6897
auto results = index.search(queries, n_neighbors);
6998
//! [Perform Queries]
7099

71100
//! [Recall]
72101
auto groundtruth = svs::load_data<int>(
73-
std::filesystem::path(SVS_DATA_DIR) / "groundtruth_euclidean.ivecs"
102+
//"/raid0/ishwarsi/datasets/open-images/oi_gtruth_1M.ivecs"
103+
"/raid0/ishwarsi/datasets/rqa/rqa_1M_gtruth_ood.ivecs"
74104
);
75105
double recall = svs::k_recall_at_n(groundtruth, results, n_neighbors, n_neighbors);
76106

@@ -79,23 +109,23 @@ int main() {
79109

80110
// STEP 4: Saving and reloading the index
81111
//! [Saving Loading]
82-
index.save("config", "graph", "data");
83-
index = svs::Vamana::assemble<float>(
84-
"config",
85-
svs::GraphLoader("graph"),
86-
svs::lib::load_from_disk<svs::leanvec::LeanDataset<
87-
svs::leanvec::UsingLVQ<4>,
88-
svs::leanvec::UsingLVQ<8>,
89-
svs::Dynamic,
90-
svs::Dynamic>>("data", padding),
91-
svs::distance::DistanceL2(),
92-
num_threads
93-
);
94-
//! [Saving Loading]
95-
index.set_search_window_size(search_window_size);
96-
recall = svs::k_recall_at_n(groundtruth, results, n_neighbors, n_neighbors);
112+
//index.save("config", "graph", "data");
113+
//index = svs::Vamana::assemble<float>(
114+
//"config",
115+
//svs::GraphLoader("graph"),
116+
//svs::lib::load_from_disk<svs::leanvec::LeanDataset<
117+
//svs::leanvec::UsingLVQ<4>,
118+
//svs::leanvec::UsingLVQ<8>,
119+
//svs::Dynamic,
120+
//svs::Dynamic>>("data", padding),
121+
//svs::distance::DistanceL2(),
122+
//num_threads
123+
//);
124+
////! [Saving Loading]
125+
//index.set_search_window_size(search_window_size);
126+
//recall = svs::k_recall_at_n(groundtruth, results, n_neighbors, n_neighbors);
97127

98-
fmt::print("Recall@{} after saving and reloading = {:.4f}\n", n_neighbors, recall);
128+
//fmt::print("Recall@{} after saving and reloading = {:.4f}\n", n_neighbors, recall);
99129

100130
return 0;
101131
}

include/svs/index/ivf/common.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -242,39 +242,39 @@ void compute_matmul(
242242
n // const int ldc
243243
);
244244
} else if constexpr (std::is_same_v<T, BFloat16>) {
245-
cblas_gemm_bf16bf16f32(
246-
CblasRowMajor, // CBLAS_LAYOUT layout
247-
CblasNoTrans, // CBLAS_TRANSPOSE TransA
248-
CblasTrans, // CBLAS_TRANSPOSE TransB
249-
m, // const int M
250-
n, // const int N
251-
k, // const int K
252-
1.0, // float alpha
253-
(const uint16_t*)data, // const *uint16_t A
254-
k, // const int lda
255-
(const uint16_t*)centroids, // const uint16_t* B
256-
k, // const int ldb
257-
0.0, // const float beta
258-
results, // float* c
259-
n // const int ldc
260-
);
245+
//cblas_gemm_bf16bf16f32(
246+
//CblasRowMajor, // CBLAS_LAYOUT layout
247+
//CblasNoTrans, // CBLAS_TRANSPOSE TransA
248+
//CblasTrans, // CBLAS_TRANSPOSE TransB
249+
//m, // const int M
250+
//n, // const int N
251+
//k, // const int K
252+
//1.0, // float alpha
253+
//(const uint16_t*)data, // const *uint16_t A
254+
//k, // const int lda
255+
//(const uint16_t*)centroids, // const uint16_t* B
256+
//k, // const int ldb
257+
//0.0, // const float beta
258+
//results, // float* c
259+
//n // const int ldc
260+
//);
261261
} else if constexpr (std::is_same_v<T, Float16>) {
262-
cblas_gemm_f16f16f32(
263-
CblasRowMajor, // CBLAS_LAYOUT layout
264-
CblasNoTrans, // CBLAS_TRANSPOSE TransA
265-
CblasTrans, // CBLAS_TRANSPOSE TransB
266-
m, // const int M
267-
n, // const int N
268-
k, // const int K
269-
1.0, // float alpha
270-
(const uint16_t*)data, // const *uint16_t A
271-
k, // const int lda
272-
(const uint16_t*)centroids, // const uint16_t* B
273-
k, // const int ldb
274-
0.0, // const float beta
275-
results, // float* c
276-
n // const int ldc
277-
);
262+
//cblas_gemm_f16f16f32(
263+
//CblasRowMajor, // CBLAS_LAYOUT layout
264+
//CblasNoTrans, // CBLAS_TRANSPOSE TransA
265+
//CblasTrans, // CBLAS_TRANSPOSE TransB
266+
//m, // const int M
267+
//n, // const int N
268+
//k, // const int K
269+
//1.0, // float alpha
270+
//(const uint16_t*)data, // const *uint16_t A
271+
//k, // const int lda
272+
//(const uint16_t*)centroids, // const uint16_t* B
273+
//k, // const int ldb
274+
//0.0, // const float beta
275+
//results, // float* c
276+
//n // const int ldc
277+
//);
278278
} else {
279279
throw ANNEXCEPTION("GEMM type not supported!");
280280
}

0 commit comments

Comments
 (0)