Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 70 additions & 10 deletions s3sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import sys
from typing import Optional

from smart_open import open as sopen

logger = logging.getLogger("s3sqlite")
handler = logging.StreamHandler(sys.stderr)
Expand Down Expand Up @@ -55,7 +55,67 @@ def convert_flags(flags):
return hexify(flags)
else:
raise ValueError(flags)

class SmartOpenVFS(apsw.VFS):
    """Read-only APSW VFS that serves SQLite pages over any transport
    supported by smart_open (e.g. s3:// URIs), reading by byte ranges."""

    def __init__(
        self,
        name: str,
        block_size: int = 4096,
        file_kwargs: Optional[dict] = None,
    ):
        """
        APSW VFS to read by ranges from S3.

        Args:
        * name: base name under which this VFS is registered with SQLite.
        * block_size: Block size used by the filesystem.
        * file_kwargs: Extra arguments to pass when calling the open() method
          of fs (smartopen). This may be useful to configure the cache
          strategy or transport parameters used to reach S3.
        """
        # Suffix a UUID so several instances can be registered at the same
        # time without colliding on the VFS name.
        self.name = f"{name}-{uuid.uuid4()}"
        self.block_size = block_size
        self.file_kwargs = file_kwargs if file_kwargs else {}
        super().__init__(name=self.name, base="")

    def xAccess(self, pathname, flags):
        # Existence/readability probe: attempt to open the object; any
        # failure (missing key, auth error, ...) reports "not accessible".
        try:
            with sopen(pathname):
                return True
        except Exception:
            return False

    def xFullPathname(self, filename):
        # Remote URIs are already absolute; return them unchanged.
        logger.debug("Calling VFS xFullPathname")
        logger.debug(f"Name: {self.name}")
        logger.debug(filename)
        return filename

    def xDelete(self, filename, syncdir):
        # Read-only VFS: deletion is intentionally a no-op.
        logger.debug("Calling VFS xDelete")
        logger.debug(
            f"Name: {self.name} filename: {filename}, syncdir: {syncdir}"
        )

    def xOpen(self, name, flags):
        # TODO: check flags to make sure the DB is opened in read-only mode.
        logger.debug("Calling VFS xOpen")
        # SQLite may hand us a URIFilename object instead of a plain string.
        fname = name.filename() if isinstance(name, apsw.URIFilename) else name
        logger.debug(
            f"Name: {self.name} open_name: {fname}, flags: {convert_flags(flags)}"
        )

        ofile = sopen(
            fname, mode="rb", **self.file_kwargs
        )

        return VFSFile(
            f=ofile,
            name=fname,
            flags=flags,
        )

class S3VFS(apsw.VFS):
def __init__(
Expand Down Expand Up @@ -113,7 +173,7 @@ def xOpen(self, name, flags):
fname, mode="rb", block_size=self.block_size, **self.file_kwargs
)

return S3VFSFile(
return VFSFile(
f=ofile,
name=fname,
flags=flags,
Expand All @@ -123,8 +183,8 @@ def upload_file(self, dbfile, dest):
self.fs.upload(dbfile, dest)


class S3VFSFile(apsw.VFSFile):
def __init__(self, f: s3fs.S3File, name, flags):
class VFSFile(apsw.VFSFile):
def __init__(self, f: ..., name, flags):
"""
VFS File object

Expand All @@ -143,11 +203,11 @@ def __init__(self, f: s3fs.S3File, name, flags):
def xRead(self, amount, offset) -> bytes:
logger.debug("Calling file xRead")
logger.debug(
f"Name: {self.name} file: {self.f.path}, amount: {amount} offset: {offset}"
f"Name: {self.name} amount: {amount} offset: {offset}"
)
self.f.seek(offset)
data = self.f.read(amount)
logger.debug(f"Read data: {data}")
# logger.debug(f"Read data: {data}")
return data

def xFileControl(self, *args):
Expand All @@ -171,13 +231,13 @@ def xSectorSize(self):

def xClose(self):
logger.debug("Calling file xClose")
logger.debug(f"Name: {self.name} file: {self.f.path}")
logger.debug(f"Name: {self.name}")
self.f.close()
pass

def xFileSize(self):
logger.debug("Calling file xFileSize")
logger.debug(f"Name: {self.name} file: {self.f.path}")
logger.debug(f"Name: {self.name}")
pos = self.f.tell()
self.f.seek(0, 2)
size = self.f.tell()
Expand All @@ -188,7 +248,7 @@ def xFileSize(self):
def xSync(self, flags):
logger.debug("Calling file xSync")
logger.debug(
f"Name: {self.name} file: {self.f.path}, flags: {convert_flags(flags)}"
f"Name: {self.name} flags: {convert_flags(flags)}"
)
pass

Expand All @@ -200,6 +260,6 @@ def xTruncate(self, newsize):
def xWrite(self, data, offset):
logger.debug("Calling file xWrite")
logger.debug(
f"Name: {self.name} file: {self.f.path}, data_size: {len(data)}, offset: {offset}, data: {data}"
f"Name: {self.name} data_size: {len(data)}, offset: {offset}, data: {data}"
)
pass
47 changes: 47 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,21 @@ def localvfs(local_fs):
return s3sqlite.S3VFS(name="local-vfs", fs=local_fs)


@pytest.fixture
def smartopenvfs(s3_data):
    """SmartOpenVFS instance wired to the test S3 endpoint via a boto3 client."""
    s3_client = boto3.client(
        "s3",
        endpoint_url=s3_data["endpoint_url"],
        aws_access_key_id=s3_data["key"],
        aws_secret_access_key=s3_data["secret"],
        region_name="us-east-1",
    )
    vfs = s3sqlite.SmartOpenVFS(
        name="smart-open-vfs",
        file_kwargs={"transport_params": {"client": s3_client}},
    )
    yield vfs


@contextmanager
def transaction(conn):
conn.execute("BEGIN;")
Expand Down Expand Up @@ -274,3 +289,35 @@ def test_s3vfs_query(bucket, s3vfs, get_db, query):
local_c = get_db[1].execute(query)
c = conn.execute(query)
assert c.fetchall() == local_c.fetchall()


@pytest.mark.parametrize("query", QUERIES)
def test_smartopenvfs_query_wal(bucket, s3vfs, smartopenvfs, get_db_wal, query):
    # Upload the WAL-mode database, then query it back through SmartOpenVFS
    # and compare against the local copy.
    dest = f"{bucket}/{dbname}"
    s3vfs.upload_file(get_db_wal[0], dest=dest)

    # SmartOpenVFS requires an s3:// URI
    with apsw.Connection(
        f"s3://{dest}", vfs=smartopenvfs.name, flags=apsw.SQLITE_OPEN_READONLY
    ) as conn:
        expected = get_db_wal[1].execute(query).fetchall()
        assert conn.execute(query).fetchall() == expected


@pytest.mark.parametrize("query", QUERIES)
def test_smartopenvfs_query(bucket, s3vfs, smartopenvfs, get_db, query):
    # Upload the rollback-journal database, then query it back through
    # SmartOpenVFS and compare against the local copy.
    dest = f"{bucket}/{dbname}"
    s3vfs.upload_file(get_db[0], dest=dest)

    # SmartOpenVFS requires an s3:// URI
    with apsw.Connection(
        f"s3://{dest}", vfs=smartopenvfs.name, flags=apsw.SQLITE_OPEN_READONLY
    ) as conn:
        expected = get_db[1].execute(query).fetchall()
        assert conn.execute(query).fetchall() == expected