From 3587c848e60cadb214884b617798724a1bc94132 Mon Sep 17 00:00:00 2001 From: Jack Crawford Date: Thu, 10 Nov 2022 13:07:48 -0800 Subject: [PATCH 1/6] Init PSI io refactor --- misc/psi/io/io.h | 63 ++++++++++++++++++++++++++--------------- utils/common/Reader.h | 66 +++++++++++++++++++++++++------------------ 2 files changed, 78 insertions(+), 51 deletions(-) diff --git a/misc/psi/io/io.h b/misc/psi/io/io.h index ee0059bf7..87965b5e8 100644 --- a/misc/psi/io/io.h +++ b/misc/psi/io/io.h @@ -38,24 +38,19 @@ using sharedContext = std::shared_ptr; using Ptxt = helib::Ptxt; template -helib::Database readDbFromFile(const std::string& databaseFilePath, - const sharedContext& contextp, - const helib::PubKey& pk) +helib::Database readDbFromStream(std::istream& databaseFileStream, + const sharedContext& contextp, + const helib::PubKey& pk) { - // Read in TXT file header - std::ifstream databaseFile(databaseFilePath); - if (!databaseFile.is_open()) { - throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); - } - TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { - std::tie(nrow, ncol) = parseDimsHeader(readline(databaseFile)); + std::tie(nrow, ncol) = parseDimsHeader(readline(databaseFileStream)); } else { - reader.emplace(Reader(databaseFilePath, zero_txt)); + // TODO + reader.emplace(Reader(databaseFileStream, zero_txt)); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -66,7 +61,7 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(databaseFile, ptxt, '\n'); + std::getline(databaseFileStream, ptxt, '\n'); } // Populate Matrix for (long i = 0; i < nrow; ++i) { @@ -87,23 +82,18 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, } template -helib::Matrix readQueryFromFile(const std::string& queryFilePath, - const helib::PubKey& pk) +helib::Matrix readQueryFromStream(std::istream& queryFileStream, + const helib::PubKey& pk) { - // Read in TXT file header - std::ifstream queryFile(queryFilePath); - if (!queryFile.is_open()) { - throw std::runtime_error("Could not open file '" + queryFilePath + "'."); - } - TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { // Ptxt query - std::tie(nrow, ncol) = parseDimsHeader(readline(queryFile)); + std::tie(nrow, ncol) = parseDimsHeader(readline(queryFileStream)); } else { // Ctxt query - reader.emplace(Reader(queryFilePath, zero_txt)); + // TODO + reader.emplace(Reader(queryFileStream, zero_txt)); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -119,7 +109,7 @@ helib::Matrix readQueryFromFile(const std::string& queryFilePath, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(queryFile, ptxt, '\n'); + std::getline(queryFileStream, ptxt, '\n'); } // Populate Matrix for (long i = 0; i < ptxt_strings.size(); ++i) { @@ -141,6 +131,33 @@ helib::Matrix readQueryFromFile(const std::string& queryFilePath, return query; } +template +helib::Database readDbFromFile(const std::string& databaseFilePath, + const sharedContext& contextp, + const helib::PubKey& pk) +{ + // Read in TXT file header + std::ifstream databaseFile(databaseFilePath); + if (!databaseFile.is_open()) { + throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); + } + + return readDbFromStream(databaseFile, contextp, pk); +} + +template +helib::Matrix readQueryFromFile(const std::string& queryFilePath, + const helib::PubKey& pk) +{ + // Read in TXT file header + std::ifstream queryFile(queryFilePath); + if (!queryFile.is_open()) { + throw std::runtime_error("Could not open file '" + queryFilePath + "'."); + } + + return readQueryFromStream(queryFile, pk); +} + // Writes out a matrix to file template inline void writeResultsToFile(const std::string& outFilePath, diff --git a/utils/common/Reader.h b/utils/common/Reader.h index cecdae36c..ef9cfcae1 100644 --- a/utils/common/Reader.h +++ b/utils/common/Reader.h @@ -26,57 +26,67 @@ class Reader private: const std::string filepath; - std::ifstream readStream; + std::shared_ptr streamPtr; + // std::ifstream readStream; D& scratch; std::shared_ptr toc; public: Reader(const std::string& fname, D& init) : filepath(fname), - readStream(filepath, std::ios::binary), + streamPtr(std::make_shared(filepath, std::ios::binary)), scratch(init), toc(std::make_shared()) { - if (!readStream.is_open()) - throw std::runtime_error("Could not open '" + filepath + "'."); - toc->read(readStream); + // if (!streamPtr->is_open()) + // throw std::runtime_error("Could not open '" + filepath + "'."); + toc->read(*streamPtr); + } + + Reader(std::istream& istream, D& init) : + filepath("__STREAM__"), + streamPtr(istream), + scratch(init), + toc(std::make_shared()) + { + toc->read(*streamPtr); } Reader(const Reader& rdr) : filepath(rdr.filepath), - readStream(filepath, std::ios::binary), + streamPtr(std::make_shared(filepath, std::ios::binary)), scratch(rdr.scratch), toc(rdr.toc) { - if (!readStream.is_open()) - throw std::runtime_error("Could not open '" + rdr.filepath + "'."); + // if (!streamPtr->is_open()) + // throw std::runtime_error("Could not open '" + rdr.filepath + "'."); } void readDatum(D& dest, int i, int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); - readStream.seekg(toc->getIdx(i, j)); - dest.read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + dest.read(*streamPtr); } std::unique_ptr readDatum(int i, int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); std::unique_ptr ptr = std::make_unique(scratch); - readStream.seekg(toc->getIdx(i, j)); - ptr->read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + ptr->read(*streamPtr); return std::move(ptr); } std::unique_ptr>> readAll() { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto m_ptr = std::make_unique>>( toc->getRows(), @@ -84,8 +94,8 @@ class Reader for (int i = 0; i < toc->getRows(); i++) { for (int j = 0; j < toc->getCols(); j++) { - readStream.seekg(toc->getIdx(i, j)); - (*m_ptr)[i][j].read(readStream); + streamPtr->seekg(toc->getIdx(i, j)); + (*m_ptr)[i][j].read(*streamPtr); } } @@ -95,13 +105,13 @@ class Reader std::unique_ptr> readRow(int i) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto v_ptr = std::make_unique>(toc->getCols(), scratch); for (int n = 0; n < toc->getCols(); n++) { - readStream.seekg(toc->getIdx(i, n)); - (*v_ptr)[n].read(readStream); + streamPtr->seekg(toc->getIdx(i, n)); + (*v_ptr)[n].read(*streamPtr); } return std::move(v_ptr); @@ -109,13 +119,13 @@ class Reader std::unique_ptr> readCol(int j) { - if (readStream.eof()) - readStream.clear(); + if (streamPtr->eof()) + streamPtr->clear(); auto v_ptr = std::make_unique>(toc->getRows(), scratch); for (int n = 0; n < toc->getRows(); n++) { - readStream.seekg(toc->getIdx(n, j)); - (*v_ptr)[n].read(readStream); + streamPtr->seekg(toc->getIdx(n, j)); + (*v_ptr)[n].read(*streamPtr); } return std::move(v_ptr); From 532912210f74267562a8fb299b60ab8d1e3859f5 Mon Sep 17 00:00:00 2001 From: Hamish Hunt Date: Fri, 11 Nov 2022 13:17:46 +0000 Subject: [PATCH 2/6] stream dispenser --- misc/psi/io/stream_dispenser.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 misc/psi/io/stream_dispenser.h diff --git a/misc/psi/io/stream_dispenser.h b/misc/psi/io/stream_dispenser.h new file mode 100644 index 000000000..4f9fdee56 --- /dev/null +++ b/misc/psi/io/stream_dispenser.h @@ -0,0 +1,20 @@ +template +class StreamDispenser { + private: + std::tuple args; + + public: + StreamDispenser(Args... args): args(std::make_tuple(args...)) {} + + std::unique_ptr get() const { + return std::apply([](const Args... args){ + return std::make_unique(args...); + }, args); + } +}; + +template +inline auto make_stream_dispenser(Args... args) { + return StreamDispenser(args...); +} + From c6d043462d6422fe0f06601b9e8587649cf03cd0 Mon Sep 17 00:00:00 2001 From: Hamish Hunt Date: Fri, 11 Nov 2022 14:01:48 +0000 Subject: [PATCH 3/6] missing bits --- misc/psi/io/stream_dispenser.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/misc/psi/io/stream_dispenser.h b/misc/psi/io/stream_dispenser.h index 4f9fdee56..ff744fee2 100644 --- a/misc/psi/io/stream_dispenser.h +++ b/misc/psi/io/stream_dispenser.h @@ -1,3 +1,14 @@ +/* Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef STREAM_DISPENSER_H_ +#define STREAM_DISPENSER_H_ + +#include +#include +#include + template class StreamDispenser { private: @@ -18,3 +29,4 @@ inline auto make_stream_dispenser(Args... args) { return StreamDispenser(args...); } +#endif // STREAM_DISPENSER_H_ From 1958803371adaa8761a1df7765e4a621988fa803 Mon Sep 17 00:00:00 2001 From: Jack Crawford Date: Mon, 14 Nov 2022 13:21:58 -0800 Subject: [PATCH 4/6] Integrate stream dispenser into Reader --- misc/psi/io/io.h | 44 ++++++++++--------- utils/common/Reader.h | 24 +++++----- .../io => utils/common}/stream_dispenser.h | 3 ++ 3 files changed, 41 insertions(+), 30 deletions(-) rename {misc/psi/io => utils/common}/stream_dispenser.h (80%) diff --git a/misc/psi/io/io.h b/misc/psi/io/io.h index 87965b5e8..fff675ef3 100644 --- a/misc/psi/io/io.h +++ b/misc/psi/io/io.h @@ -37,20 +37,21 @@ using sharedContext = std::shared_ptr; using Ptxt = helib::Ptxt; -template -helib::Database readDbFromStream(std::istream& databaseFileStream, +template +helib::Database readDbFromStream(StreamDispenser& databaseStreamDispenser, const sharedContext& contextp, const helib::PubKey& pk) { + auto databaseStreamPtr = databaseStreamDispenser.get(); TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { - std::tie(nrow, ncol) = parseDimsHeader(readline(databaseFileStream)); + std::tie(nrow, ncol) = parseDimsHeader(readline(*databaseStreamPtr.get())); } else { // TODO - reader.emplace(Reader(databaseFileStream, zero_txt)); + reader.emplace(databaseStreamPtr, zero_txt); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -61,7 +62,7 @@ helib::Database readDbFromStream(std::istream& databaseFileStream, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(databaseFileStream, ptxt, '\n'); + std::getline(*databaseStreamPtr.get(), ptxt, '\n'); } // Populate Matrix for (long i = 0; i < nrow; ++i) { @@ -81,19 +82,20 @@ helib::Database readDbFromStream(std::istream& databaseFileStream, return helib::Database(data, contextp); } -template -helib::Matrix readQueryFromStream(std::istream& queryFileStream, +template +helib::Matrix readQueryFromStream(StreamDispenser& queryStreamDispenser, const helib::PubKey& pk) { + auto queryStreamPtr = queryStreamDispenser.get(); TXT zero_txt(pk); // This is only needed for TXT = Ctxt std::optional> reader; long nrow, ncol; if constexpr (std::is_same_v) { // Ptxt query - std::tie(nrow, ncol) = parseDimsHeader(readline(queryFileStream)); + std::tie(nrow, ncol) = parseDimsHeader(readline(*queryStreamPtr.get())); } else { // Ctxt query // TODO - reader.emplace(Reader(queryFileStream, zero_txt)); + reader.emplace(queryStreamPtr, zero_txt); nrow = reader.value().getTOC().getRows(); ncol = reader.value().getTOC().getCols(); } @@ -109,7 +111,7 @@ helib::Matrix readQueryFromStream(std::istream& queryFileStream, // Read in ptxts std::vector ptxt_strings(nrow * ncol); for (auto& ptxt : ptxt_strings) { - std::getline(queryFileStream, ptxt, '\n'); + std::getline(*queryStreamPtr.get(), ptxt, '\n'); } // Populate Matrix for (long i = 0; i < ptxt_strings.size(); ++i) { @@ -136,26 +138,28 @@ helib::Database readDbFromFile(const std::string& databaseFilePath, const sharedContext& contextp, const helib::PubKey& pk) { + auto streamDispenser = make_stream_dispenser(databaseFilePath); // Read in TXT file header - std::ifstream databaseFile(databaseFilePath); - if (!databaseFile.is_open()) { - throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); - } + // std::ifstream databaseFile(databaseFilePath); + // if (!databaseFile.is_open()) { + // throw std::runtime_error("Could not open file '" + databaseFilePath + "'."); + // } - return readDbFromStream(databaseFile, contextp, pk); + return readDbFromStream(streamDispenser, contextp, pk); } template helib::Matrix readQueryFromFile(const std::string& queryFilePath, const helib::PubKey& pk) { + auto streamDispenser = make_stream_dispenser(queryFilePath); // Read in TXT file header - std::ifstream queryFile(queryFilePath); - if (!queryFile.is_open()) { - throw std::runtime_error("Could not open file '" + queryFilePath + "'."); - } + // std::ifstream queryFile(queryFilePath); + // if (!queryFile.is_open()) { + // throw std::runtime_error("Could not open file '" + queryFilePath + "'."); + // } - return readQueryFromStream(queryFile, pk); + return readQueryFromStream(streamDispenser, pk); } // Writes out a matrix to file diff --git a/utils/common/Reader.h b/utils/common/Reader.h index ef9cfcae1..60beb88df 100644 --- a/utils/common/Reader.h +++ b/utils/common/Reader.h @@ -18,6 +18,7 @@ #include #include +#include "stream_dispenser.h" #include "TOC.h" template @@ -26,26 +27,27 @@ class Reader private: const std::string filepath; - std::shared_ptr streamPtr; - // std::ifstream readStream; + std::unique_ptr streamPtr; D& scratch; std::shared_ptr toc; public: Reader(const std::string& fname, D& init) : filepath(fname), - streamPtr(std::make_shared(filepath, std::ios::binary)), + streamPtr(std::make_unique(filepath, std::ios::binary)), scratch(init), toc(std::make_shared()) { - // if (!streamPtr->is_open()) - // throw std::runtime_error("Could not open '" + filepath + "'."); + const std::ifstream& stream = *reinterpret_cast(streamPtr.get()); + if (!stream.is_open()) + throw std::runtime_error("Could not open '" + filepath + "'."); toc->read(*streamPtr); } - Reader(std::istream& istream, D& init) : + template + Reader(std::unique_ptr& istreamPtr, D& init) : filepath("__STREAM__"), - streamPtr(istream), + streamPtr(std::move(istreamPtr)), scratch(init), toc(std::make_shared()) { @@ -54,12 +56,14 @@ class Reader Reader(const Reader& rdr) : filepath(rdr.filepath), - streamPtr(std::make_shared(filepath, std::ios::binary)), + // This allows "copying" of streams by creating unique instances of the same stream + streamPtr(make_stream_dispenser(filepath, std::ios::binary).get()), scratch(rdr.scratch), toc(rdr.toc) { - // if (!streamPtr->is_open()) - // throw std::runtime_error("Could not open '" + rdr.filepath + "'."); + const std::ifstream& stream = *reinterpret_cast(streamPtr.get()); + if (!stream.is_open()) + throw std::runtime_error("Could not open '" + rdr.filepath + "'."); } void readDatum(D& dest, int i, int j) diff --git a/misc/psi/io/stream_dispenser.h b/utils/common/stream_dispenser.h similarity index 80% rename from misc/psi/io/stream_dispenser.h rename to utils/common/stream_dispenser.h index ff744fee2..3c024cf7d 100644 --- a/misc/psi/io/stream_dispenser.h +++ b/utils/common/stream_dispenser.h @@ -9,6 +9,9 @@ #include #include +// Do not use this class directly. Create instances using the +// make_stream_dispenser function +// eg. make_stream_dispenser("filename", std::ios::binary) template class StreamDispenser { private: From 9a830612548ca556ebf2c8ca1c45454346b83f14 Mon Sep 17 00:00:00 2001 From: Jack Crawford Date: Tue, 15 Nov 2022 08:49:37 -0800 Subject: [PATCH 5/6] Fixed bats tests --- misc/psi/io/io.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/misc/psi/io/io.h b/misc/psi/io/io.h index 96ac30a4f..3d5f0607f 100644 --- a/misc/psi/io/io.h +++ b/misc/psi/io/io.h @@ -73,7 +73,9 @@ helib::Database readDbFromStream(StreamDispenser& databaseStreamDispenser, } } else { // Ctxt query NTL_EXEC_RANGE(nrow * ncol, first, last) - Reader threadReader(reader.value()); + // Create new reader for each thread + auto databaseStreamThreadPtr = databaseStreamDispenser.get(); + Reader threadReader(databaseStreamThreadPtr, zero_txt); for (long i = first; i < last; ++i) { long row = i / ncol; long col = i % ncol; @@ -124,7 +126,9 @@ helib::Matrix readQueryFromStream(StreamDispenser& queryStreamDispenser, } else { // Ctxt query // Read in ctxts NTL_EXEC_RANGE(nrow * ncol, first, last) - Reader threadReader(reader.value()); + // Create new reader for each thread + auto queryStreamThreadPtr = queryStreamDispenser.get(); + Reader threadReader(queryStreamThreadPtr, zero_txt); for (long i = first; i < last; ++i) { long row = i / ncol; long col = i % ncol; From e2a8dc8f989e1aab9b3e7e6650c52e2044fc3e63 Mon Sep 17 00:00:00 2001 From: Jack Crawford Date: Wed, 23 Nov 2022 07:58:46 -0800 Subject: [PATCH 6/6] Add const --- misc/psi/io/io.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/misc/psi/io/io.h b/misc/psi/io/io.h index 3d5f0607f..2022ac90c 100644 --- a/misc/psi/io/io.h +++ b/misc/psi/io/io.h @@ -38,7 +38,7 @@ using sharedContext = std::shared_ptr; using Ptxt = helib::Ptxt; template -helib::Database readDbFromStream(StreamDispenser& databaseStreamDispenser, +helib::Database readDbFromStream(const StreamDispenser& databaseStreamDispenser, const sharedContext& contextp, const helib::PubKey& pk) { @@ -88,7 +88,7 @@ helib::Database readDbFromStream(StreamDispenser& databaseStreamDispenser, } template -helib::Matrix readQueryFromStream(StreamDispenser& queryStreamDispenser, +helib::Matrix readQueryFromStream(const StreamDispenser& queryStreamDispenser, const helib::PubKey& pk) { auto queryStreamPtr = queryStreamDispenser.get();