Skip to content

Commit 93e268b

Browse files
committed
feat(row-object): support to Row object for row_factory
1 parent 9b72f71 commit 93e268b

File tree

4 files changed

+133
-14
lines changed

4 files changed

+133
-14
lines changed

src/sqlitecloud/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
PARSE_DECLTYPES,
88
Connection,
99
Cursor,
10+
Row,
1011
adapters,
1112
connect,
1213
converters,
@@ -25,6 +26,7 @@
2526
"PARSE_COLNAMES",
2627
"adapters",
2728
"converters",
29+
"Row",
2830
]
2931

3032
VERSION = "0.0.79"

src/sqlitecloud/dbapi2.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,15 +341,14 @@ class Cursor(Iterator[Any]):
341341

342342
arraysize: int = 1
343343

344-
row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
345-
346344
def __init__(self, connection: Connection) -> None:
347345
self._driver = Driver()
348-
self.row_factory = None
349346
self._connection = connection
350347
self._iter_row: int = 0
351348
self._resultset: SQLiteCloudResult = None
352349

350+
self.row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
351+
353352
@property
354353
def connection(self) -> Connection:
355354
"""
@@ -577,6 +576,9 @@ def _call_row_factory(self, row: Tuple) -> object:
577576
if self.row_factory is None:
578577
return row
579578

579+
if self.row_factory is Row:
580+
return Row(row, [col[0] for col in self.description])
581+
580582
return self.row_factory(self, row)
581583

582584
def _is_result_rowset(self) -> bool:
@@ -697,6 +699,59 @@ def __next__(self) -> Optional[Tuple[Any]]:
697699
raise StopIteration
698700

699701

702+
class Row:
703+
def __init__(self, data: Tuple[Any], column_names: List[str]):
704+
"""
705+
Initialize the Row object with data and column names.
706+
707+
Args:
708+
data (Tuple[Any]): A tuple containing the row data.
709+
column_names (List[str]): A list of column names corresponding to the data.
710+
"""
711+
self._data = data
712+
self._column_names = column_names
713+
self._column_map = {name.lower(): idx for idx, name in enumerate(column_names)}
714+
715+
def keys(self) -> List[str]:
716+
"""Return the column names."""
717+
return self._column_names
718+
719+
def __getitem__(self, key):
720+
"""Support indexing by both column name and index."""
721+
if isinstance(key, int):
722+
return self._data[key]
723+
elif isinstance(key, str):
724+
return self._data[self._column_map[key.lower()]]
725+
else:
726+
raise TypeError("Invalid key type. Must be int or str.")
727+
728+
def __len__(self) -> int:
729+
return len(self._data)
730+
731+
def __iter__(self) -> Iterator[Any]:
732+
return iter(self._data)
733+
734+
def __repr__(self) -> str:
735+
return "\n".join(
736+
f"{name}: {self._data[idx]}" for idx, name in enumerate(self._column_names)
737+
)
738+
739+
def __hash__(self) -> int:
740+
return hash((self._data, tuple(self._column_map)))
741+
742+
def __eq__(self, other) -> bool:
743+
"""Check if both have the same data and column names."""
744+
if not isinstance(other, Row):
745+
return NotImplemented
746+
747+
return self._data == other._data and self._column_map == other._column_map
748+
749+
def __ne__(self, other):
750+
if not isinstance(other, Row):
751+
return NotImplemented
752+
return not self.__eq__(other)
753+
754+
700755
class MissingDecltypeException(Exception):
701756
def __init__(self, message: str) -> None:
702757
super().__init__(message)

src/tests/integration/test_dbapi2.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,16 @@ def test_row_factory(self, sqlitecloud_dbapi2_connection):
246246
assert row["AlbumId"] == 1
247247
assert row["Title"] == "For Those About To Rock We Salute You"
248248
assert row["ArtistId"] == 1
249+
250+
def test_row_object_for_factory_string_representation(
251+
self, sqlitecloud_dbapi2_connection
252+
):
253+
connection = sqlitecloud_dbapi2_connection
254+
255+
connection.row_factory = sqlitecloud.Row
256+
257+
cursor = connection.execute('SELECT "foo" as Bar, "john" Doe')
258+
259+
row = cursor.fetchone()
260+
261+
assert str(row) == "Bar: foo\nDoe: john"

src/tests/integration/test_sqlite3_parity.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,76 @@ def test_close_cursor_raises_exception(
131131
with pytest.raises(sqlite3.ProgrammingError) as e:
132132
sqlite3_cursor.fetchall()
133133

134-
def test_row_factory(self, sqlitecloud_dbapi2_connection, sqlite3_connection):
135-
sqlitecloud_connection = sqlitecloud_dbapi2_connection
134+
@pytest.mark.parametrize(
135+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
136+
)
137+
def test_row_factory(self, connection, request):
138+
connection = request.getfixturevalue(connection)
136139

137140
def simple_factory(cursor, row):
138141
return {
139142
description[0]: row[i]
140143
for i, description in enumerate(cursor.description)
141144
}
142145

143-
sqlitecloud_connection.row_factory = simple_factory
144-
sqlite3_connection.row_factory = simple_factory
146+
connection.row_factory = simple_factory
145147

146-
select_query = "SELECT * FROM albums WHERE AlbumId = 1"
147-
sqlitecloud_cursor = sqlitecloud_connection.execute(select_query)
148-
sqlite3_cursor = sqlite3_connection.execute(select_query)
148+
select_query = "SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 1"
149+
cursor = connection.execute(select_query)
149150

150-
sqlitecloud_results = sqlitecloud_cursor.fetchall()
151-
sqlite3_results = sqlite3_cursor.fetchall()
151+
results = cursor.fetchall()
152152

153-
assert sqlitecloud_results == sqlite3_results
154-
assert sqlitecloud_results[0]["Title"] == sqlite3_results[0]["Title"]
153+
assert results[0]["AlbumId"] == 1
154+
assert results[0]["Title"] == "For Those About To Rock We Salute You"
155+
assert results[0]["ArtistId"] == 1
156+
assert connection.row_factory == cursor.row_factory
157+
158+
@pytest.mark.parametrize(
159+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
160+
)
161+
def test_cursor_row_factory_as_instance_variable(self, connection, request):
162+
connection = request.getfixturevalue(connection)
163+
164+
cursor = connection.execute("SELECT 1")
165+
cursor.row_factory = lambda c, r: list(r)
166+
167+
cursor2 = connection.execute("SELECT 1")
168+
169+
assert cursor.row_factory != cursor2.row_factory
170+
171+
@pytest.mark.parametrize(
172+
"connection, module",
173+
[
174+
("sqlitecloud_dbapi2_connection", sqlitecloud),
175+
("sqlite3_connection", sqlite3),
176+
],
177+
)
178+
def test_row_factory_with_row_object(self, connection, module, request):
179+
connection = request.getfixturevalue(connection)
180+
181+
connection.row_factory = module.Row
182+
183+
select_query = "SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 1"
184+
cursor = connection.execute(select_query)
185+
186+
row = cursor.fetchone()
187+
188+
assert row["AlbumId"] == 1
189+
assert row["Title"] == "For Those About To Rock We Salute You"
190+
assert row[1] == row["Title"]
191+
assert row["Title"] == row["title"]
192+
assert row.keys() == ["AlbumId", "Title", "ArtistId"]
193+
assert len(row) == 3
194+
assert next(iter(row)) == 1 # AlbumId
195+
assert not row != row
196+
assert row == row
197+
198+
cursor = connection.execute(
199+
"SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 2"
200+
)
201+
other_row = cursor.fetchone()
202+
203+
assert row != other_row
155204

156205
@pytest.mark.parametrize(
157206
"connection",

0 commit comments

Comments
 (0)