Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/dalotia.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TensorFile *make_tensor_file(const std::string &filename) {
::tolower);

// select the file implementation
if (extension == "safetensors") {
if (extension == "safetensors" || extension == "st") {
#ifdef DALOTIA_WITH_SAFETENSORS_CPP
return new SafetensorsFile(filename);
#else // DALOTIA_WITH_SAFETENSORS_CPP
Expand All @@ -69,13 +69,37 @@ TensorFile *make_tensor_file(const std::string &filename) {
return nullptr;
}


// factory function for the file, selected by file extension and
// available implementations
TensorFile *load_tensor_file_from_memory(const void * const address, size_t num_bytes, const std::string &format) {
auto& extension = format;
// select the file implementation
if (extension == "safetensors") {
#ifdef DALOTIA_WITH_SAFETENSORS_CPP
return new SafetensorsFile(address, num_bytes);
#else // DALOTIA_WITH_SAFETENSORS_CPP
throw std::runtime_error("Safetensors support not enabled");
#endif // DALOTIA_WITH_SAFETENSORS_CPP
} else {
throw std::runtime_error("Unsupported memory format: ." + extension);
}
return nullptr;
}

} // namespace dalotia

DalotiaTensorFile *dalotia_open_file(const char *filename) {
return reinterpret_cast<DalotiaTensorFile *>(
dalotia::make_tensor_file(std::string(filename)));
}

DalotiaTensorFile *dalotia_load_file_from_memory(const void * const address, size_t num_bytes, const char *format) {
return reinterpret_cast<DalotiaTensorFile *>(
dalotia::load_tensor_file_from_memory(address, num_bytes, std::string(format)));
}


void dalotia_close_file(DalotiaTensorFile *file) {
delete reinterpret_cast<dalotia::TensorFile *>(file);
}
Expand Down
8 changes: 8 additions & 0 deletions src/dalotia.f90
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ type(C_ptr) function dalotia_open_file_c(file_name) bind(C,name="dalotia_open_fi
character(kind=C_char), dimension(*), intent(in):: file_name
end function dalotia_open_file_c

type(C_ptr) function dalotia_load_file_from_memory_c(address, num_bytes, file_format) bind(C,name="dalotia_load_file_from_memory")
use, intrinsic::ISO_C_BINDING, only: C_ptr, C_char, C_size_t
implicit none
type(C_ptr), intent(in), value :: address
integer(C_size_t) :: num_bytes
character(kind=C_char), dimension(*), intent(in):: file_format
end function dalotia_load_file_from_memory_c

subroutine dalotia_close_file(dalotia_file_pointer) bind(C,name="dalotia_close_file")
use, intrinsic::ISO_C_BINDING, only: C_ptr
implicit none
Expand Down
2 changes: 2 additions & 0 deletions src/dalotia.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ typedef struct DalotiaTensorFile DalotiaTensorFile;

EXTERNC DalotiaTensorFile *dalotia_open_file(const char *filename);

EXTERNC DalotiaTensorFile *dalotia_load_file_from_memory(const void *address, size_t num_bytes, const char *format);

EXTERNC void dalotia_close_file(DalotiaTensorFile *file);

EXTERNC int dalotia_sizeof_weight_format(dalotia_WeightFormat format);
Expand Down
2 changes: 2 additions & 0 deletions src/dalotia.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace dalotia {
// available implementations
[[nodiscard]] TensorFile *make_tensor_file(const std::string & filename);

[[nodiscard]] TensorFile *load_tensor_file_from_memory(const void * const address, size_t num_bytes, const char *format);

// C++17 version -> will not compile on Fugaku...
// -- pmr vector types can accept different allocators
//? more memory interface than that? detect if CUDA device pointer through
Expand Down
21 changes: 21 additions & 0 deletions src/dalotia_safetensors_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ SafetensorsFile::SafetensorsFile(const std::string &filename) : TensorFile(filen
#endif // NDEBUG
}

SafetensorsFile::SafetensorsFile(const void * const address, size_t num_bytes) : TensorFile("") {
// as far as I can tell, safetensors are saved in C order
std::string warn, err;
bool ret = safetensors::mmap_from_memory(static_cast<const uint8_t*>(address), num_bytes, "", &st_, &warn, &err);
if (warn.size() > 0) {
std::cout << "safetensors-cpp WARN: " << warn << "\n";
}
if (ret == false) {
std::cerr << " ERR: " << err << "\n";
throw std::runtime_error("Could not load safetensors from address");
}
#ifndef NDEBUG
// Check if data_offsets are valid
if (!safetensors::validate_data_offsets(st_, err)) {
std::cerr << "Invalid data_offsets\n";
std::cerr << err << "\n";
throw std::runtime_error("Invalid safetensors address");
}
#endif // NDEBUG
}

SafetensorsFile::~SafetensorsFile() {
if (st_.st_file != nullptr) {
// delete st_.st_file;
Expand Down
2 changes: 2 additions & 0 deletions src/dalotia_safetensors_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class SafetensorsFile : public TensorFile {
public:
explicit SafetensorsFile(const std::string &filename);

SafetensorsFile(const void * const address, size_t num_bytes);

~SafetensorsFile() override;

const std::vector<std::string> &get_tensor_names() const override;
Expand Down