diff --git a/misc/psi/io/io.h b/misc/psi/io/io.h index abfed3382..2022ac90c 100644 --- a/misc/psi/io/io.h +++ b/misc/psi/io/io.h @@ -37,25 +37,21 @@ using sharedContext = std::shared_ptr; using Ptxt = helib::Ptxt; -template -helib::Database readDbFromFile(const std::string& databaseFilePath, - const sharedContext& contextp, - const helib::PubKey& pk) +template +helib::Database readDbFromStream(const StreamDispenser& databaseStreamDispenser, + const sharedContext& contextp, + const helib::PubKey& pk) { - // Read in TXT file header - std::ifstream databaseFile(databaseFilePath); - if (!databaseFile.is_open()) { - throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); - } - + auto databaseStreamPtr = databaseStreamDispenser.get(); TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { - std::tie(nrow, ncol) = parseDimsHeader(readline(databaseFile)); + std::tie(nrow, ncol) = parseDimsHeader(readline(*databaseStreamPtr.get())); } else { - reader.emplace(Reader(databaseFilePath, zero_txt)); + // TODO + reader.emplace(databaseStreamPtr, zero_txt); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -66,7 +62,7 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(databaseFile, ptxt, '\n'); + std::getline(*databaseStreamPtr.get(), ptxt, '\n'); } // Populate Matrix for (long i = 0; i < nrow; ++i) { @@ -77,7 +73,9 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, } } else { // Ctxt query NTL_EXEC_RANGE(nrow * ncol, first, last) - Reader threadReader(reader.value()); + // Create new reader for each thread + auto databaseStreamThreadPtr = databaseStreamDispenser.get(); + Reader threadReader(databaseStreamThreadPtr, zero_txt); for (long i = first; i < last; ++i) { long row = i / ncol; long col = i % ncol; @@ -89,24 +87,20 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, return helib::Database(data, contextp); } -template -helib::Matrix readQueryFromFile(const std::string& queryFilePath, - const helib::PubKey& pk) +template +helib::Matrix readQueryFromStream(const StreamDispenser& queryStreamDispenser, + const helib::PubKey& pk) { - // Read in TXT file header - std::ifstream queryFile(queryFilePath); - if (!queryFile.is_open()) { - throw std::runtime_error("Could not open file '" + queryFilePath + "'."); - } - + auto queryStreamPtr = queryStreamDispenser.get(); TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { // Ptxt query - std::tie(nrow, ncol) = parseDimsHeader(readline(queryFile)); + std::tie(nrow, ncol) = parseDimsHeader(readline(*queryStreamPtr.get())); } else { // Ctxt query - reader.emplace(Reader(queryFilePath, zero_txt)); + // TODO + reader.emplace(queryStreamPtr, zero_txt); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -122,7 +116,7 @@ helib::Matrix readQueryFromFile(const std::string& queryFilePath, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(queryFile, ptxt, '\n'); + std::getline(*queryStreamPtr.get(), ptxt, '\n'); } // Populate Matrix for (long i = 0; i < ptxt_strings.size(); ++i) { @@ -132,7 +126,9 @@ helib::Matrix readQueryFromFile(const std::string& queryFilePath, } else { // Ctxt query // Read in ctxts NTL_EXEC_RANGE(nrow * ncol, first, last) - Reader threadReader(reader.value()); + // Create new reader for each thread + auto queryStreamThreadPtr = queryStreamDispenser.get(); + Reader threadReader(queryStreamThreadPtr, zero_txt); for (long i = first; i < last; ++i) { long row = i / ncol; long col = i % ncol; @@ -147,6 +143,35 @@ helib::Matrix readQueryFromFile(const std::string& queryFilePath, return query; } +template +helib::Database readDbFromFile(const std::string& databaseFilePath, + const sharedContext& contextp, + const helib::PubKey& pk) +{ + auto streamDispenser = make_stream_dispenser(databaseFilePath); + // Read in TXT file header + // std::ifstream databaseFile(databaseFilePath); + // if (!databaseFile.is_open()) { + // throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); + // } + + return readDbFromStream(streamDispenser, contextp, pk); +} + +template +helib::Matrix readQueryFromFile(const std::string& queryFilePath, + const helib::PubKey& pk) +{ + auto streamDispenser = make_stream_dispenser(queryFilePath); + // Read in TXT file header + // std::ifstream queryFile(queryFilePath); + // if (!queryFile.is_open()) { + // throw std::runtime_error("Could not open file '" + queryFilePath + "'."); + // } + + return readQueryFromStream(streamDispenser, pk); +} + // Writes out a matrix to file template inline void writeResultsToFile(const std::string& outFilePath, diff --git a/utils/common/Reader.h b/utils/common/Reader.h index cecdae36c..60beb88df 100644 --- a/utils/common/Reader.h +++ b/utils/common/Reader.h @@ -18,6 +18,7 @@ #include #include +#include "stream_dispenser.h" #include "TOC.h" template @@ -26,57 +27,70 @@ class Reader private: const std::string filepath; - std::ifstream readStream; + std::unique_ptr streamPtr; D& scratch; std::shared_ptr toc; public: Reader(const std::string& fname, D& init) : filepath(fname), - readStream(filepath, std::ios::binary), + streamPtr(std::make_unique(filepath, std::ios::binary)), scratch(init), toc(std::make_shared()) { - if (!readStream.is_open()) + const std::ifstream& stream = *reinterpret_cast(streamPtr.get()); + if (!stream.is_open()) throw std::runtime_error("Could not open '" + filepath + "'."); - toc->read(readStream); + toc->read(*streamPtr); + } + + template + Reader(std::unique_ptr& istreamPtr, D& init) : + filepath("__STREAM__"), + streamPtr(std::move(istreamPtr)), + scratch(init), + toc(std::make_shared()) + { + toc->read(*streamPtr); } Reader(const Reader& rdr) : filepath(rdr.filepath), - readStream(filepath, std::ios::binary), + // This allows "copying" of streams by creating unique instances of the same stream + streamPtr(make_stream_dispenser(filepath, std::ios::binary).get()), scratch(rdr.scratch), toc(rdr.toc) { - if (!readStream.is_open()) + const std::ifstream& stream = *reinterpret_cast(streamPtr.get()); + if (!stream.is_open()) throw std::runtime_error("Could not open '" + rdr.filepath + "'."); } void readDatum(D& dest, int i, int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); - readStream.seekg(toc->getIdx(i, j)); - dest.read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + dest.read(*streamPtr); } std::unique_ptr readDatum(int i, int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); std::unique_ptr ptr = std::make_unique(scratch); - readStream.seekg(toc->getIdx(i, j)); - ptr->read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + ptr->read(*streamPtr); return std::move(ptr); } std::unique_ptr>> readAll() { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto m_ptr = std::make_unique>>( toc->getRows(), @@ -84,8 +98,8 @@ class Reader for (int i = 0; i < toc->getRows(); i++) { for (int j = 0; j < toc->getCols(); j++) { - readStream.seekg(toc->getIdx(i, j)); - (*m_ptr)[i][j].read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + (*m_ptr)[i][j].read(*streamPtr); } } @@ -95,13 +109,13 @@ class Reader std::unique_ptr> readRow(int i) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto v_ptr = std::make_unique>(toc->getCols(), scratch); for (int n = 0; n < toc->getCols(); n++) { - readStream.seekg(toc->getIdx(i, n)); - (*v_ptr)[n].read(readStream); + streamPtr->seekg(toc->getIdx(i, n)); + (*v_ptr)[n].read(*streamPtr); } return std::move(v_ptr); @@ -109,13 +123,13 @@ class Reader std::unique_ptr> readCol(int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto v_ptr = std::make_unique>(toc->getRows(), scratch); for (int n = 0; n < toc->getRows(); n++) { - readStream.seekg(toc->getIdx(n, j)); - (*v_ptr)[n].read(readStream); + streamPtr->seekg(toc->getIdx(n, j)); + (*v_ptr)[n].read(*streamPtr); } return std::move(v_ptr); diff --git a/utils/common/stream_dispenser.h b/utils/common/stream_dispenser.h new file mode 100644 index 000000000..3c024cf7d --- /dev/null +++ b/utils/common/stream_dispenser.h @@ -0,0 +1,35 @@ +/* Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef STREAM_DISPENSER_H_ +#define STREAM_DISPENSER_H_ + +#include +#include +#include + +// Do not use this class directly. Create instances using the +// make_stream_dispenser function +// eg. make_stream_dispenser("filename", std::ios::binary) +template +class StreamDispenser { + private: + std::tuple args; + + public: + StreamDispenser(Args... args): args(std::make_tuple(args...)) {} + + std::unique_ptr get() const { + return std::apply([](const Args... args){ + return std::make_unique(args...); + }, args); + } +}; + +template +inline auto make_stream_dispenser(Args... args) { + return StreamDispenser(args...); +} + +#endif // STREAM_DISPENSER_H_