diff --git a/s3sqlite.py b/s3sqlite.py
index debe947..7900a9e 100644
--- a/s3sqlite.py
+++ b/s3sqlite.py
@@ -4,7 +4,7 @@
 import logging
 import sys
 from typing import Optional
-
+from smart_open import open as sopen
 
 logger = logging.getLogger("s3sqlite")
 handler = logging.StreamHandler(sys.stderr)
@@ -55,7 +55,67 @@ def convert_flags(flags):
         return hexify(flags)
     else:
         raise ValueError(flags)
+
+class SmartOpenVFS(apsw.VFS):
+    def __init__(
+        self,
+        name: str,
+        block_size: int = 4096,
+        file_kwargs: Optional[dict] = None,
+    ):
+        """
+        APSW VFS that reads by ranges from S3 via smart_open.
+
+        Args:
+
+        * name: S3 path of the file (bucket + prefix + filename)
+        * block_size: Block size used by the filesystem.
+        * file_kwargs: Extra arguments to pass to smart_open's open().
+          Useful to configure the transport (e.g. a custom boto3 client).
+        """
+        self.name = f"{name}-{str(uuid.uuid4())}"
+        self.block_size = block_size
+        self.file_kwargs = file_kwargs if file_kwargs else {}
+        super().__init__(name=self.name, base="")
+
+    def xAccess(self, pathname, flags):
+        try:
+            with sopen(pathname):
+                return True
+        except Exception:
+            return False
+
+    def xFullPathname(self, filename):
+        logger.debug("Calling VFS xFullPathname")
+        logger.debug(f"Name: {self.name}")
+        logger.debug(filename)
+        return filename
+
+    def xDelete(self, filename, syncdir):
+        logger.debug("Calling VFS xDelete")
+        logger.debug(
+            f"Name: {self.name} filename: {filename}, syncdir: {syncdir}"
+        )
+        pass
+
+    def xOpen(self, name, flags):
+        # TODO: check flags to make sure the DB is opened in read-only mode.
+        logger.debug("Calling VFS xOpen")
+        fname = name.filename() if isinstance(name, apsw.URIFilename) else name
+        logger.debug(
+            f"Name: {self.name} open_name: {fname}, flags: {convert_flags(flags)}"
+        )
+
+        ofile = sopen(
+            fname, mode="rb", **self.file_kwargs
+        )
+
+        return VFSFile(
+            f=ofile,
+            name=fname,
+            flags=flags,
+        )
 
 
 class S3VFS(apsw.VFS):
     def __init__(
@@ -113,7 +173,7 @@ def xOpen(self, name, flags):
             fname, mode="rb", block_size=self.block_size, **self.file_kwargs
         )
 
-        return S3VFSFile(
+        return VFSFile(
             f=ofile,
             name=fname,
             flags=flags,
@@ -123,8 +183,8 @@ def upload_file(self, dbfile, dest):
         self.fs.upload(dbfile, dest)
 
 
-class S3VFSFile(apsw.VFSFile):
-    def __init__(self, f: s3fs.S3File, name, flags):
+class VFSFile(apsw.VFSFile):
+    def __init__(self, f, name, flags):
         """
         VFS File object
 
@@ -143,11 +203,11 @@ def __init__(self, f: s3fs.S3File, name, flags):
     def xRead(self, amount, offset) -> bytes:
         logger.debug("Calling file xRead")
         logger.debug(
-            f"Name: {self.name} file: {self.f.path}, amount: {amount} offset: {offset}"
+            f"Name: {self.name} amount: {amount} offset: {offset}"
        )
         self.f.seek(offset)
         data = self.f.read(amount)
-        logger.debug(f"Read data: {data}")
+        # logger.debug(f"Read data: {data}")
         return data
 
     def xFileControl(self, *args):
@@ -171,13 +231,13 @@ def xSectorSize(self):
 
     def xClose(self):
         logger.debug("Calling file xClose")
-        logger.debug(f"Name: {self.name} file: {self.f.path}")
+        logger.debug(f"Name: {self.name}")
         self.f.close()
         pass
 
     def xFileSize(self):
         logger.debug("Calling file xFileSize")
-        logger.debug(f"Name: {self.name} file: {self.f.path}")
+        logger.debug(f"Name: {self.name}")
         pos = self.f.tell()
         self.f.seek(0, 2)
         size = self.f.tell()
@@ -188,7 +248,7 @@ def xFileSize(self):
 
     def xSync(self, flags):
         logger.debug("Calling file xSync")
         logger.debug(
-            f"Name: {self.name} file: {self.f.path}, flags: {convert_flags(flags)}"
+            f"Name: {self.name} flags: {convert_flags(flags)}"
         )
         pass
@@ -200,6 +260,6 @@ def xTruncate(self, newsize):
     def xWrite(self, data, offset):
         logger.debug("Calling file xWrite")
         logger.debug(
-            f"Name: {self.name} file: {self.f.path}, data_size: {len(data)}, offset: {offset}, data: {data}"
+            f"Name: {self.name} data_size: {len(data)}, offset: {offset}, data: {data}"
         )
         pass
diff --git a/test.py b/test.py
index b7ccd35..1f9587a 100644
--- a/test.py
+++ b/test.py
@@ -157,6 +157,21 @@ def localvfs(local_fs):
     return s3sqlite.S3VFS(name="local-vfs", fs=local_fs)
 
 
+@pytest.fixture
+def smartopenvfs(s3_data):
+    client = boto3.client(
+        "s3",
+        endpoint_url=s3_data["endpoint_url"],
+        aws_access_key_id=s3_data["key"],
+        aws_secret_access_key=s3_data["secret"],
+        region_name="us-east-1",
+    )
+    yield s3sqlite.SmartOpenVFS(
+        name="smart-open-vfs",
+        file_kwargs={"transport_params": {"client": client}},
+    )
+
+
 @contextmanager
 def transaction(conn):
     conn.execute("BEGIN;")
@@ -274,3 +289,31 @@ def test_s3vfs_query(bucket, s3vfs, get_db, query):
     local_c = get_db[1].execute(query)
     c = conn.execute(query)
     assert c.fetchall() == local_c.fetchall()
+
+
+@pytest.mark.parametrize("query", QUERIES)
+def test_smartopenvfs_query_wal(bucket, s3vfs, smartopenvfs, get_db_wal, query):
+    key_prefix = f"{bucket}/{dbname}"
+    s3vfs.upload_file(get_db_wal[0], dest=key_prefix)
+
+    # SmartOpenVFS requires an s3:// URI
+    with apsw.Connection(
+        f"s3://{key_prefix}", vfs=smartopenvfs.name, flags=apsw.SQLITE_OPEN_READONLY
+    ) as conn:
+        local_c = get_db_wal[1].execute(query)
+        c = conn.execute(query)
+        assert c.fetchall() == local_c.fetchall()
+
+
+@pytest.mark.parametrize("query", QUERIES)
+def test_smartopenvfs_query(bucket, s3vfs, smartopenvfs, get_db, query):
+    key_prefix = f"{bucket}/{dbname}"
+    s3vfs.upload_file(get_db[0], dest=key_prefix)
+
+    # SmartOpenVFS requires an s3:// URI
+    with apsw.Connection(
+        f"s3://{key_prefix}", vfs=smartopenvfs.name, flags=apsw.SQLITE_OPEN_READONLY
+    ) as conn:
+        local_c = get_db[1].execute(query)
+        c = conn.execute(query)
+        assert c.fetchall() == local_c.fetchall()
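
Usage sketch for the new SmartOpenVFS, mirroring the tests above. The bucket and
key ("mybucket", "test.sqlite3") are placeholders, and credentials are assumed
to come from boto3's default chain unless a client is passed via
file_kwargs={"transport_params": {"client": ...}}:

    import apsw
    import s3sqlite

    # Registering the VFS appends a uuid to the name, so pass vfs=vfs.name below.
    vfs = s3sqlite.SmartOpenVFS(name="smart-open-vfs")

    # SmartOpenVFS requires an s3:// URI and a read-only connection.
    with apsw.Connection(
        "s3://mybucket/test.sqlite3",
        vfs=vfs.name,
        flags=apsw.SQLITE_OPEN_READONLY,
    ) as conn:
        print(conn.execute("SELECT name FROM sqlite_master").fetchall())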