Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions include/npy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,43 @@ inline npy_data<Scalar> read_npy(std::istream &in) {
return data;
}

template <typename Scalar, typename It>
inline npy_data_ptr<Scalar> read_npy(std::istream &in, It dst_buf, size_t count) {
static_assert(std::is_same<Scalar, typename std::iterator_traits<It>::value_type>::value,
"Scalar and It::value_type must be identical");

std::string header_s = read_header(in);

// parse header
header_t header = parse_header(header_s);

// check if the typestring matches the given one
const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));

if (header.dtype.tie() != dtype.tie()) {
throw std::runtime_error("formatting error: typestrings not matching");
}

// compute the data size based on the shape
auto size = static_cast<size_t>(comp_size(header.shape));

if (size > count) {
throw std::runtime_error("dst_buf too small to hold file contents");
}

npy_data_ptr<Scalar> data;

data.shape = header.shape;
data.fortran_order = header.fortran_order;

data.data_ptr = &(*dst_buf);

// read the data
in.read(reinterpret_cast<char *>(data.data_ptr), sizeof(Scalar) * size);

return data;
}

template <typename Scalar>
inline npy_data<Scalar> read_npy(const std::string &filename) {
std::ifstream stream(filename, std::ifstream::binary);
Expand All @@ -513,6 +550,16 @@ inline npy_data<Scalar> read_npy(const std::string &filename) {
return read_npy<Scalar>(stream);
}

template <typename Scalar, typename It>
inline npy_data_ptr<Scalar> read_npy(const std::string &filename, It dst_buf, size_t count) {
std::ifstream stream(filename, std::ifstream::binary);
if (!stream) {
throw std::runtime_error("io error: failed to open a file.");
}

return read_npy<Scalar>(stream, dst_buf, count);
}

template <typename Scalar>
inline void write_npy(std::ostream &out, const npy_data<Scalar> &data) {
// static_assert(has_typestring<Scalar>::value, "scalar type not
Expand Down