Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 52 additions & 27 deletions misc/psi/io/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,21 @@
using sharedContext = std::shared_ptr<helib::Context>;
using Ptxt = helib::Ptxt<helib::BGV>;

template <typename TXT>
helib::Database<TXT> readDbFromFile(const std::string& databaseFilePath,
const sharedContext& contextp,
const helib::PubKey& pk)
template <typename TXT, typename StreamDispenser>
helib::Database<TXT> readDbFromStream(const StreamDispenser& databaseStreamDispenser,
const sharedContext& contextp,
const helib::PubKey& pk)
{
// Read in TXT file header
std::ifstream databaseFile(databaseFilePath);
if (!databaseFile.is_open()) {
throw std::runtime_error("Could not open file '" + databaseFilePath + "'.");
}

auto databaseStreamPtr = databaseStreamDispenser.get();
TXT zero_txt(pk);
// This is only needed for TXT = Ctxt
std::optional<Reader<TXT>> reader;
long nrow, ncol;
if constexpr (std::is_same_v<TXT, Ptxt>) {
std::tie(nrow, ncol) = parseDimsHeader(readline(databaseFile));
std::tie(nrow, ncol) = parseDimsHeader(readline(*databaseStreamPtr.get()));
} else {
reader.emplace(Reader<helib::Ctxt>(databaseFilePath, zero_txt));
// TODO
reader.emplace(databaseStreamPtr, zero_txt);
nrow = reader.value().getTOC().getRows();
ncol = reader.value().getTOC().getCols();
}
Expand All @@ -66,7 +62,7 @@ helib::Database<TXT> readDbFromFile(const std::string& databaseFilePath,
// Read in ptxts
std::vector<std::string> ptxt_strings(nrow * ncol);
for (auto& ptxt : ptxt_strings) {
std::getline(databaseFile, ptxt, '\n');
std::getline(*databaseStreamPtr.get(), ptxt, '\n');
}
// Populate Matrix
for (long i = 0; i < nrow; ++i) {
Expand All @@ -77,7 +73,9 @@ helib::Database<TXT> readDbFromFile(const std::string& databaseFilePath,
}
} else { // Ctxt query
NTL_EXEC_RANGE(nrow * ncol, first, last)
Reader<TXT> threadReader(reader.value());
// Create new reader for each thread
auto databaseStreamThreadPtr = databaseStreamDispenser.get();
Reader<TXT> threadReader(databaseStreamThreadPtr, zero_txt);
for (long i = first; i < last; ++i) {
long row = i / ncol;
long col = i % ncol;
Expand All @@ -89,24 +87,20 @@ helib::Database<TXT> readDbFromFile(const std::string& databaseFilePath,
return helib::Database<TXT>(data, contextp);
}

template <typename TXT>
helib::Matrix<TXT> readQueryFromFile(const std::string& queryFilePath,
const helib::PubKey& pk)
template <typename TXT, typename StreamDispenser>
helib::Matrix<TXT> readQueryFromStream(const StreamDispenser& queryStreamDispenser,
const helib::PubKey& pk)
{
// Read in TXT file header
std::ifstream queryFile(queryFilePath);
if (!queryFile.is_open()) {
throw std::runtime_error("Could not open file '" + queryFilePath + "'.");
}

auto queryStreamPtr = queryStreamDispenser.get();
TXT zero_txt(pk);
// This is only needed for TXT = Ctxt
std::optional<Reader<TXT>> reader;
long nrow, ncol;
if constexpr (std::is_same_v<TXT, Ptxt>) { // Ptxt query
std::tie(nrow, ncol) = parseDimsHeader(readline(queryFile));
std::tie(nrow, ncol) = parseDimsHeader(readline(*queryStreamPtr.get()));
} else { // Ctxt query
reader.emplace(Reader<helib::Ctxt>(queryFilePath, zero_txt));
// TODO
reader.emplace(queryStreamPtr, zero_txt);
nrow = reader.value().getTOC().getRows();
ncol = reader.value().getTOC().getCols();
}
Expand All @@ -122,7 +116,7 @@ helib::Matrix<TXT> readQueryFromFile(const std::string& queryFilePath,
// Read in ptxts
std::vector<std::string> ptxt_strings(nrow * ncol);
for (auto& ptxt : ptxt_strings) {
std::getline(queryFile, ptxt, '\n');
std::getline(*queryStreamPtr.get(), ptxt, '\n');
}
// Populate Matrix
for (long i = 0; i < ptxt_strings.size(); ++i) {
Expand All @@ -132,7 +126,9 @@ helib::Matrix<TXT> readQueryFromFile(const std::string& queryFilePath,
} else { // Ctxt query
// Read in ctxts
NTL_EXEC_RANGE(nrow * ncol, first, last)
Reader<TXT> threadReader(reader.value());
// Create new reader for each thread
auto queryStreamThreadPtr = queryStreamDispenser.get();
Reader<TXT> threadReader(queryStreamThreadPtr, zero_txt);
for (long i = first; i < last; ++i) {
long row = i / ncol;
long col = i % ncol;
Expand All @@ -147,6 +143,35 @@ helib::Matrix<TXT> readQueryFromFile(const std::string& queryFilePath,
return query;
}

template <typename TXT>
helib::Database<TXT> readDbFromFile(const std::string& databaseFilePath,
const sharedContext& contextp,
const helib::PubKey& pk)
{
auto streamDispenser = make_stream_dispenser<std::ifstream>(databaseFilePath);
// Read in TXT file header
// std::ifstream databaseFile(databaseFilePath);
// if (!databaseFile.is_open()) {
// throw std::runtime_error("Could not open file '" + databaseFilePath + "'.");
// }

return readDbFromStream<TXT>(streamDispenser, contextp, pk);
}

template <typename TXT>
helib::Matrix<TXT> readQueryFromFile(const std::string& queryFilePath,
const helib::PubKey& pk)
{
auto streamDispenser = make_stream_dispenser<std::ifstream>(queryFilePath);
// Read in TXT file header
// std::ifstream queryFile(queryFilePath);
// if (!queryFile.is_open()) {
// throw std::runtime_error("Could not open file '" + queryFilePath + "'.");
// }

return readQueryFromStream<TXT>(streamDispenser, pk);
}

// Writes out a matrix to file
template <typename TXT>
inline void writeResultsToFile(const std::string& outFilePath,
Expand Down
66 changes: 40 additions & 26 deletions utils/common/Reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <fstream>
#include <exception>

#include "stream_dispenser.h"
#include "TOC.h"

template <typename D>
Expand All @@ -26,66 +27,79 @@ class Reader

private:
const std::string filepath;
std::ifstream readStream;
std::unique_ptr<std::istream> streamPtr;
D& scratch;
std::shared_ptr<TOC> toc;

public:
Reader(const std::string& fname, D& init) :
filepath(fname),
readStream(filepath, std::ios::binary),
streamPtr(std::make_unique<std::ifstream>(filepath, std::ios::binary)),
scratch(init),
toc(std::make_shared<TOC>())
{
if (!readStream.is_open())
const std::ifstream& stream = *reinterpret_cast<std::ifstream*>(streamPtr.get());
if (!stream.is_open())
throw std::runtime_error("Could not open '" + filepath + "'.");
toc->read(readStream);
toc->read(*streamPtr);
}

template <typename STREAM>
Reader(std::unique_ptr<STREAM>& istreamPtr, D& init) :
filepath("__STREAM__"),
streamPtr(std::move(istreamPtr)),
scratch(init),
toc(std::make_shared<TOC>())
{
toc->read(*streamPtr);
}

Reader(const Reader& rdr) :
filepath(rdr.filepath),
readStream(filepath, std::ios::binary),
// This allows "copying" of streams by creating unique instances of the same stream
streamPtr(make_stream_dispenser<std::ifstream>(filepath, std::ios::binary).get()),
scratch(rdr.scratch),
toc(rdr.toc)
{
if (!readStream.is_open())
const std::ifstream& stream = *reinterpret_cast<std::ifstream*>(streamPtr.get());
if (!stream.is_open())
throw std::runtime_error("Could not open '" + rdr.filepath + "'.");
}

void readDatum(D& dest, int i, int j)
{
if (readStream.eof())
readStream.clear();
if (streamPtr->eof())
streamPtr->clear();

readStream.seekg(toc->getIdx(i, j));
dest.read(readStream);
streamPtr->seekg(toc->getIdx(i, j));
dest.read(*streamPtr);
}

std::unique_ptr<D> readDatum(int i, int j)
{
if (readStream.eof())
readStream.clear();
if (streamPtr->eof())
streamPtr->clear();

std::unique_ptr<D> ptr = std::make_unique<D>(scratch);
readStream.seekg(toc->getIdx(i, j));
ptr->read(readStream);
streamPtr->seekg(toc->getIdx(i, j));
ptr->read(*streamPtr);

return std::move(ptr);
}

std::unique_ptr<std::vector<std::vector<D>>> readAll()
{
if (readStream.eof())
readStream.clear();
if (streamPtr->eof())
streamPtr->clear();

auto m_ptr = std::make_unique<std::vector<std::vector<D>>>(
toc->getRows(),
std::vector<D>(toc->getCols(), scratch));

for (int i = 0; i < toc->getRows(); i++) {
for (int j = 0; j < toc->getCols(); j++) {
readStream.seekg(toc->getIdx(i, j));
(*m_ptr)[i][j].read(readStream);
streamPtr->seekg(toc->getIdx(i, j));
(*m_ptr)[i][j].read(*streamPtr);
}
}

Expand All @@ -95,27 +109,27 @@ class Reader
std::unique_ptr<std::vector<D>> readRow(int i)
{

if (readStream.eof())
readStream.clear();
if (streamPtr->eof())
streamPtr->clear();

auto v_ptr = std::make_unique<std::vector<D>>(toc->getCols(), scratch);
for (int n = 0; n < toc->getCols(); n++) {
readStream.seekg(toc->getIdx(i, n));
(*v_ptr)[n].read(readStream);
streamPtr->seekg(toc->getIdx(i, n));
(*v_ptr)[n].read(*streamPtr);
}

return std::move(v_ptr);
}

std::unique_ptr<std::vector<D>> readCol(int j)
{
if (readStream.eof())
readStream.clear();
if (streamPtr->eof())
streamPtr->clear();

auto v_ptr = std::make_unique<std::vector<D>>(toc->getRows(), scratch);
for (int n = 0; n < toc->getRows(); n++) {
readStream.seekg(toc->getIdx(n, j));
(*v_ptr)[n].read(readStream);
streamPtr->seekg(toc->getIdx(n, j));
(*v_ptr)[n].read(*streamPtr);
}

return std::move(v_ptr);
Expand Down
35 changes: 35 additions & 0 deletions utils/common/stream_dispenser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (C) 2022 Intel Corporation
* SPDX-License-Identifier: Apache-2.0
*/

#ifndef STREAM_DISPENSER_H_
#define STREAM_DISPENSER_H_

#include<functional>
#include<tuple>
#include<memory>

// Do not use this class directly. Create instances using the
// make_stream_dispenser function
// eg. make_stream_dispenser<std::ifstream>("filename", std::ios::binary)
template<typename Stream, typename... Args>
class StreamDispenser {
private:
std::tuple<Args...> args;

public:
StreamDispenser(Args... args): args(std::make_tuple(args...)) {}

std::unique_ptr<Stream> get() const {
return std::apply([](const Args... args){
return std::make_unique<Stream>(args...);
}, args);
}
};

template<typename Stream, typename... Args>
inline auto make_stream_dispenser(Args... args) {
return StreamDispenser<Stream, Args...>(args...);
}

#endif // STREAM_DISPENSER_H_