11#!/usr/bin/env python
22from __future__ import annotations
33
4+ import abc
45import collections
56import logging
67import os
1920 TYPE_CHECKING ,
2021 Any ,
2122 Callable ,
23+ Dict ,
24+ Generic ,
2225 Iterator ,
2326 Literal ,
2427 NamedTuple ,
2528 NoReturn ,
2629 Sequence ,
30+ Tuple ,
2731 TypeVar ,
32+ Union ,
2833 overload ,
2934)
3035
8691 from .result_batch import ResultBatch
8792
8893T = TypeVar ("T" , bound = collections .abc .Sequence )
94+ FetchRow = TypeVar ("FetchRow" , bound = Union [Tuple [Any , ...], Dict [str , Any ]])
8995
9096logger = getLogger (__name__ )
9197
@@ -332,29 +338,7 @@ class ResultState(Enum):
332338 RESET = 3
333339
334340
335- class SnowflakeCursor :
336- """Implementation of Cursor object that is returned from Connection.cursor() method.
337-
338- Attributes:
339- description: A list of namedtuples about metadata for all columns.
340- rowcount: The number of records updated or selected. If not clear, -1 is returned.
341- rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
342- determined.
343- sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
344- sqlstate: Snowflake SQL State code.
345- timestamp_output_format: Snowflake timestamp_output_format for timestamps.
346- timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
347- timestamp_tz_output_format: Snowflake output format for TZ timestamps.
348- timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
349- date_output_format: Snowflake output format for dates.
350- time_output_format: Snowflake output format for times.
351- timezone: Snowflake timezone.
352- binary_output_format: Snowflake output format for binary fields.
353- arraysize: The default number of rows fetched by fetchmany.
354- connection: The connection object by which the cursor was created.
355- errorhandle: The class that handles error handling.
356- is_file_transfer: Whether, or not the current command is a put, or get.
357- """
341+ class SnowflakeCursorBase (abc .ABC , Generic [FetchRow ]):
358342
359343 # TODO:
360344 # Most of these attributes have no reason to be properties, we could just store them in public variables.
@@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None:
382366 def __init__ (
383367 self ,
384368 connection : SnowflakeConnection ,
385- use_dict_result : bool = False ,
386369 ) -> None :
387370 """Inits a SnowflakeCursor with a connection.
388371
389372 Args:
390373 connection: The connection that created this cursor.
391- use_dict_result: Decides whether to use dict result or not.
392374 """
393375 self ._connection : SnowflakeConnection = connection
394376
@@ -423,7 +405,6 @@ def __init__(
423405 self ._result : Iterator [tuple ] | Iterator [dict ] | None = None
424406 self ._result_set : ResultSet | None = None
425407 self ._result_state : ResultState = ResultState .DEFAULT
426- self ._use_dict_result = use_dict_result
427408 self .query : str | None = None
428409 # TODO: self._query_result_format could be defined as an enum
429410 self ._query_result_format : str | None = None
@@ -435,7 +416,7 @@ def __init__(
435416 self ._first_chunk_time = None
436417
437418 self ._log_max_query_length = connection .log_max_query_length
438- self ._inner_cursor : SnowflakeCursor | None = None
419+ self ._inner_cursor : SnowflakeCursorBase | None = None
439420 self ._prefetch_hook = None
440421 self ._rownumber : int | None = None
441422
@@ -448,6 +429,12 @@ def __del__(self) -> None: # pragma: no cover
448429 if logger .getEffectiveLevel () <= logging .INFO :
449430 logger .info (e )
450431
432+ @property
433+ @abc .abstractmethod
434+ def _use_dict_result (self ) -> bool :
435+ """Decides whether results from helper functions are returned as a dict."""
436+ pass
437+
451438 @property
452439 def description (self ) -> list [ResultMetadata ]:
453440 if self ._description is None :
@@ -1514,8 +1501,17 @@ def executemany(
15141501
15151502 return self
15161503
1517- def fetchone (self ) -> dict | tuple | None :
1518- """Fetches one row."""
1504+ @abc .abstractmethod
1505+ def fetchone (self ) -> FetchRow :
1506+ pass
1507+
1508+ def _fetchone (self ) -> dict [str , Any ] | tuple [Any , ...] | None :
1509+ """
1510+ Fetches one row.
1511+
1512+ Returns a dict if self._use_dict_result is True, otherwise
1513+ returns tuple.
1514+ """
15191515 if self ._prefetch_hook is not None :
15201516 self ._prefetch_hook ()
15211517 if self ._result is None and self ._result_set is not None :
@@ -1539,7 +1535,7 @@ def fetchone(self) -> dict | tuple | None:
15391535 else :
15401536 return None
15411537
1542- def fetchmany (self , size : int | None = None ) -> list [tuple ] | list [ dict ]:
1538+ def fetchmany (self , size : int | None = None ) -> list [FetchRow ]:
15431539 """Fetches the number of specified rows."""
15441540 if size is None :
15451541 size = self .arraysize
@@ -1565,7 +1561,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
15651561
15661562 return ret
15671563
1568- def fetchall (self ) -> list [tuple ] | list [ dict ]:
1564+ def fetchall (self ) -> list [FetchRow ]:
15691565 """Fetches all of the results."""
15701566 ret = []
15711567 while True :
@@ -1728,20 +1724,31 @@ def wait_until_ready() -> None:
17281724 # Unset this function, so that we don't block anymore
17291725 self ._prefetch_hook = None
17301726
1731- if (
1732- self ._inner_cursor ._total_rowcount == 1
1733- and self ._inner_cursor .fetchall ()
1734- == [("Multiple statements executed successfully." ,)]
1727+ if self ._inner_cursor ._total_rowcount == 1 and _is_successful_multi_stmt (
1728+ self ._inner_cursor .fetchall ()
17351729 ):
17361730 url = f"/queries/{ sfqid } /result"
17371731 ret = self ._connection .rest .request (url = url , method = "get" )
17381732 if "data" in ret and "resultIds" in ret ["data" ]:
17391733 self ._init_multi_statement_results (ret ["data" ])
17401734
1735+ def _is_successful_multi_stmt (rows : list [Any ]) -> bool :
1736+ if len (rows ) != 1 :
1737+ return False
1738+ row = rows [0 ]
1739+ if isinstance (row , tuple ):
1740+ return row == ("Multiple statements executed successfully." ,)
1741+ elif isinstance (row , dict ):
1742+ return row == {
1743+ "multiple statement execution" : "Multiple statements executed successfully."
1744+ }
1745+ else :
1746+ return False
1747+
17411748 self .connection .get_query_status_throw_if_error (
17421749 sfqid
17431750 ) # Trigger an exception if query failed
1744- self ._inner_cursor = SnowflakeCursor (self .connection )
1751+ self ._inner_cursor = self . __class__ (self .connection )
17451752 self ._sfqid = sfqid
17461753 self ._prefetch_hook = wait_until_ready
17471754
@@ -1925,14 +1932,53 @@ def _create_file_transfer_agent(
19251932 )
19261933
19271934
1928- class DictCursor (SnowflakeCursor ):
1935+ class SnowflakeCursor (SnowflakeCursorBase [tuple [Any , ...]]):
1936+ """Implementation of Cursor object that is returned from Connection.cursor() method.
1937+
1938+ Attributes:
1939+ description: A list of namedtuples about metadata for all columns.
1940+ rowcount: The number of records updated or selected. If not clear, -1 is returned.
1941+ rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
1942+ determined.
1943+ sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
1944+ sqlstate: Snowflake SQL State code.
1945+ timestamp_output_format: Snowflake timestamp_output_format for timestamps.
1946+ timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
1947+ timestamp_tz_output_format: Snowflake output format for TZ timestamps.
1948+ timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
1949+ date_output_format: Snowflake output format for dates.
1950+ time_output_format: Snowflake output format for times.
1951+ timezone: Snowflake timezone.
1952+ binary_output_format: Snowflake output format for binary fields.
1953+ arraysize: The default number of rows fetched by fetchmany.
1954+ connection: The connection object by which the cursor was created.
1955+ errorhandle: The class that handles error handling.
1956+ is_file_transfer: Whether, or not the current command is a put, or get.
1957+ """
1958+
1959+ @property
1960+ def _use_dict_result (self ) -> bool :
1961+ return False
1962+
1963+ def fetchone (self ) -> tuple [Any , ...] | None :
1964+ row = self ._fetchone ()
1965+ if not (row is None or isinstance (row , tuple )):
1966+ raise TypeError (f"fetchone got unexpected result: { row } " )
1967+ return row
1968+
1969+
1970+ class DictCursor (SnowflakeCursorBase [dict [str , Any ]]):
19291971 """Cursor returning results in a dictionary."""
19301972
1931- def __init__ (self , connection ) -> None :
1932- super ().__init__ (
1933- connection ,
1934- use_dict_result = True ,
1935- )
1973+ @property
1974+ def _use_dict_result (self ) -> bool :
1975+ return True
1976+
1977+ def fetchone (self ) -> dict [str , Any ] | None :
1978+ row = self ._fetchone ()
1979+ if not (row is None or isinstance (row , dict )):
1980+ raise TypeError (f"fetchone got unexpected result: { row } " )
1981+ return row
19361982
19371983
19381984def __getattr__ (name ):
0 commit comments