Skip to content

Commit 8b9c53f

Browse files
committed
feat(decltypes): support to parse_decltypes
1 parent 3637f01 commit 8b9c53f

File tree

9 files changed

+719
-173
lines changed

9 files changed

+719
-173
lines changed

bandit-baseline.json

Lines changed: 288 additions & 33 deletions
Large diffs are not rendered by default.

src/sqlitecloud/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,25 @@
22
# the classes and functions from the dbapi2 module.
33
# eg: sqlite3.connect() -> sqlitecloud.connect()
44
#
5-
from .dbapi2 import Connection, Cursor, connect, register_adapter
5+
from .dbapi2 import (
6+
PARSE_COLNAMES,
7+
PARSE_DECLTYPES,
8+
Connection,
9+
Cursor,
10+
connect,
11+
register_adapter,
12+
register_converter,
13+
)
614

7-
__all__ = ["VERSION", "Connection", "Cursor", "connect", "register_adapter"]
15+
__all__ = [
16+
"VERSION",
17+
"Connection",
18+
"Cursor",
19+
"connect",
20+
"register_adapter",
21+
"register_converter",
22+
"PARSE_DECLTYPES",
23+
"PARSE_COLNAMES",
24+
]
825

926
VERSION = "0.0.79"

src/sqlitecloud/dbapi2.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
PARSE_DECLTYPES = 1
5151
PARSE_COLNAMES = 2
5252

53-
# Adapter registry to convert Python types to SQLite types
54-
adapters = {}
53+
# Adapters registry to convert Python types to SQLite types
54+
_adapters = {}
55+
# Converters registry to convert SQLite types to Python types
56+
_converters = {}
5557

5658

5759
@overload
@@ -106,6 +108,11 @@ def connect(
106108
It can be either a connection string or a `SqliteCloudAccount` object.
107109
config (Optional[SQLiteCloudConfig]): The configuration options for the connection.
108110
Defaults to None.
111+
detect_types (int): Default (0), disabled. How data types not natively supported
112+
by SQLite are looked up to be converted to Python types, using the converters
113+
registered with register_converter().
114+
Accepts any combination (using |, bitwise or) of PARSE_DECLTYPES and PARSE_COLNAMES.
115+
Column names takes precedence over declared types if both flags are set.
109116
110117
Returns:
111118
Connection: A DB-API 2.0 connection object representing the connection to the database.
@@ -122,13 +129,16 @@ def connect(
122129
else:
123130
config = SQLiteCloudConfig(connection_info)
124131

125-
return Connection(
126-
driver.connect(config.account.hostname, config.account.port, config)
132+
connection = Connection(
133+
driver.connect(config.account.hostname, config.account.port, config),
134+
detect_types=detect_types,
127135
)
128136

137+
return connection
138+
129139

130140
def register_adapter(
131-
pytype: Type, adapter_callable: Callable[[object], SQLiteTypes]
141+
pytype: Type, adapter_callable: Callable[[Any], SQLiteTypes]
132142
) -> None:
133143
"""
134144
Registers a callable to convert the type into one of the supported SQLite types.
@@ -138,8 +148,21 @@ def register_adapter(
138148
callable (Callable): The callable that converts the type into a supported
139149
SQLite supported type.
140150
"""
141-
global adapters
142-
adapters[pytype] = adapter_callable
151+
global _adapters
152+
_adapters[pytype] = adapter_callable
153+
154+
155+
def register_converter(type_name: str, converter: Callable[[bytes], Any]) -> None:
156+
"""
157+
Registers a callable to convert a bytestring from the database into a custom Python type.
158+
159+
Args:
160+
type_name (str): The name of the type to convert.
161+
The match with the name of the type in the query is case-insensitive.
162+
converter (Callable): The callable that converts the bytestring into the custom Python type.
163+
"""
164+
global _converters
165+
_converters[type_name.lower()] = converter
143166

144167

145168
class Connection:
@@ -154,16 +177,16 @@ class Connection:
154177
SQLiteCloud_connection (SQLiteCloudConnect): The SQLite Cloud connection object.
155178
"""
156179

157-
def __init__(self, sqlitecloud_connection: SQLiteCloudConnect) -> None:
180+
def __init__(
181+
self, sqlitecloud_connection: SQLiteCloudConnect, detect_types: int = 0
182+
) -> None:
158183
self._driver = Driver()
159184
self.sqlitecloud_connection = sqlitecloud_connection
160185

161186
self.row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
162-
self.text_factory: Union[
163-
Type[Union[str, bytes]], Callable[[bytes], object]
164-
] = str
187+
self.text_factory: Union[Type[Union[str, bytes]], Callable[[bytes], Any]] = str
165188

166-
self.detect_types = 0
189+
self.detect_types = detect_types
167190

168191
@property
169192
def sqlcloud_connection(self) -> SQLiteCloudConnect:
@@ -273,19 +296,19 @@ def cursor(self):
273296
cursor.row_factory = self.row_factory
274297
return cursor
275298

276-
def _apply_adapter(self, value: object) -> SQLiteTypes:
299+
def _apply_adapter(self, value: Any) -> SQLiteTypes:
277300
"""
278301
Applies the registered adapter to convert the Python type into a SQLite supported type.
279302
In the case there is no registered adapter, it calls the __conform__() method when the value object implements it.
280303
281304
Args:
282-
value (object): The Python type to convert.
305+
value (Any): The Python type to convert.
283306
284307
Returns:
285308
SQLiteTypes: The SQLite supported type or the given value when no adapter is found.
286309
"""
287-
if type(value) in adapters:
288-
return adapters[type(value)](value)
310+
if type(value) in _adapters:
311+
return _adapters[type(value)](value)
289312

290313
if hasattr(value, "__conform__"):
291314
# we don't support sqlite3.PrepareProtocol
@@ -445,6 +468,8 @@ def executemany(
445468

446469
commands = ""
447470
for parameters in seq_of_parameters:
471+
parameters = self._adapt_parameters(parameters)
472+
448473
prepared_statement = self._driver.prepare_statement(sql, parameters)
449474
commands += prepared_statement + ";"
450475

@@ -547,24 +572,51 @@ def _adapt_parameters(self, parameters: Union[Dict, Tuple]) -> Union[Dict, Tuple
547572

548573
return tuple(self._connection._apply_adapter(p) for p in parameters)
549574

575+
def _convert_value(self, value: Any, decltype: Optional[str]) -> Any:
576+
# todo: parse columns first
577+
578+
if (self.connection.detect_types & PARSE_DECLTYPES) == PARSE_DECLTYPES:
579+
return self._parse_decltypes(value, decltype)
580+
581+
if decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value or (
582+
decltype is None and isinstance(value, str)
583+
):
584+
return self._apply_text_factory(value)
585+
586+
return value
587+
588+
def _parse_decltypes(self, value: Any, decltype: str) -> Any:
589+
decltype = decltype.lower()
590+
if decltype in _converters:
591+
# sqlite3 always passes value as bytes
592+
value = (
593+
str(value).encode("utf-8") if not isinstance(value, bytes) else value
594+
)
595+
return _converters[decltype](value)
596+
597+
return value
598+
599+
def _apply_text_factory(self, value: Any) -> Any:
600+
"""Use Connection.text_factory to convert value with TEXT column or
601+
string value with undleclared column type."""
602+
603+
if self._connection.text_factory is bytes:
604+
return value.encode("utf-8")
605+
if self._connection.text_factory is not str and callable(
606+
self._connection.text_factory
607+
):
608+
return self._connection.text_factory(value.encode("utf-8"))
609+
610+
return value
611+
550612
def _get_value(self, row: int, col: int) -> Optional[Any]:
551613
if not self._is_result_rowset():
552614
return None
553615

554-
# Convert TEXT type with text_factory
616+
value = self._resultset.get_value(row, col)
555617
decltype = self._resultset.get_decltype(col)
556-
if decltype is None or decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value:
557-
value = self._resultset.get_value(row, col, False)
558-
559-
if self._connection.text_factory is bytes:
560-
return value.encode("utf-8")
561-
if self._connection.text_factory is not str and callable(
562-
self._connection.text_factory
563-
):
564-
return self._connection.text_factory(value.encode("utf-8"))
565-
return value
566618

567-
return self._resultset.get_value(row, col)
619+
return self._convert_value(value, decltype)
568620

569621
def __iter__(self) -> "Cursor":
570622
return self
@@ -602,7 +654,7 @@ def adapt_datetime(val):
602654
return val.isoformat(" ")
603655

604656
def convert_date(val):
605-
return datetime.date(*map(int, val.split(b"-")))
657+
return date(*map(int, val.split(b"-")))
606658

607659
def convert_timestamp(val):
608660
datepart, timepart = val.split(b" ")
@@ -614,13 +666,13 @@ def convert_timestamp(val):
614666
else:
615667
microseconds = 0
616668

617-
val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds)
669+
val = datetime(year, month, day, hours, minutes, seconds, microseconds)
618670
return val
619671

620672
register_adapter(date, adapt_date)
621673
register_adapter(datetime, adapt_datetime)
622-
# register_converter("date", convert_date)
623-
# register_converter("timestamp", convert_timestamp)
674+
register_converter("date", convert_date)
675+
register_converter("timestamp", convert_timestamp)
624676

625677

626678
register_adapters_and_converters()

src/sqlitecloud/resultset.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,12 @@ def _compute_index(self, row: int, col: int) -> int:
6060
return -1
6161
return row * self.ncols + col
6262

63-
def get_value(self, row: int, col: int, convert: bool = True) -> Optional[any]:
63+
def get_value(self, row: int, col: int) -> Optional[any]:
6464
index = self._compute_index(row, col)
6565
if index < 0 or not self.data or index >= len(self.data):
6666
return None
6767

68-
value = self.data[index]
69-
return self._convert(value, col) if convert else value
68+
return self.data[index]
7069

7170
def get_name(self, col: int) -> Optional[str]:
7271
if col < 0 or col >= self.ncols:
@@ -79,23 +78,6 @@ def get_decltype(self, col: int) -> Optional[str]:
7978

8079
return self.decltype[col]
8180

82-
def _convert(self, value: str, col: int) -> any:
83-
if col < 0 or col >= len(self.decltype):
84-
return value
85-
86-
decltype = self.decltype[col]
87-
if decltype == SQLITECLOUD_VALUE_TYPE.INTEGER.value:
88-
return int(value)
89-
if decltype == SQLITECLOUD_VALUE_TYPE.FLOAT.value:
90-
return float(value)
91-
if decltype == SQLITECLOUD_VALUE_TYPE.BLOB.value:
92-
# values are received as bytes before being strings
93-
return bytes(value)
94-
if decltype == SQLITECLOUD_VALUE_TYPE.NULL.value:
95-
return None
96-
97-
return value
98-
9981

10082
class SQLiteCloudResultSet:
10183
def __init__(self, result: SQLiteCloudResult) -> None:

src/tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def sqlitecloud_dbapi2_connection():
4545
yield next(get_sqlitecloud_dbapi2_connection())
4646

4747

48-
def get_sqlitecloud_dbapi2_connection():
48+
def get_sqlitecloud_dbapi2_connection(detect_types: int = 0):
4949
account = SQLiteCloudAccount()
5050
account.username = os.getenv("SQLITE_USER")
5151
account.password = os.getenv("SQLITE_PASSWORD")
5252
account.dbname = os.getenv("SQLITE_DB")
5353
account.hostname = os.getenv("SQLITE_HOST")
5454
account.port = int(os.getenv("SQLITE_PORT"))
5555

56-
connection = sqlitecloud.connect(account)
56+
connection = sqlitecloud.connect(account, detect_types=detect_types)
5757

5858
assert isinstance(connection, sqlitecloud.Connection)
5959

@@ -62,12 +62,13 @@ def get_sqlitecloud_dbapi2_connection():
6262
connection.close()
6363

6464

65-
def get_sqlite3_connection():
65+
def get_sqlite3_connection(detect_types: int = 0):
6666
# set isolation_level=None to enable autocommit
6767
# and to be aligned with the behavior of SQLite Cloud
6868
connection = sqlite3.connect(
6969
os.path.join(os.path.dirname(__file__), "./assets/chinook.sqlite"),
7070
isolation_level=None,
71+
detect_types=detect_types,
7172
)
7273
yield connection
7374
connection.close()

src/tests/integration/test_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import random
23
import time
34

45
import pytest
@@ -641,10 +642,10 @@ def test_big_rowset(self):
641642

642643
connection = client.open_connection()
643644

644-
table_name = "TestCompress" + str(int(time.time()))
645+
table_name = "TestCompress" + str(random.randint(0, 99999))
645646
try:
646647
client.exec_query(
647-
f"CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY, name TEXT)",
648+
f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY, name TEXT)",
648649
connection,
649650
)
650651

@@ -663,7 +664,7 @@ def test_big_rowset(self):
663664

664665
assert rowset.nrows == nRows
665666
finally:
666-
client.exec_query(f"DROP TABLE {table_name}", connection)
667+
client.exec_query(f"DROP TABLE IF EXISTS {table_name}", connection)
667668
client.disconnect(connection)
668669

669670
def test_compression_single_column(self):

0 commit comments

Comments
 (0)