2222import math
2323import uuid
2424from decimal import Decimal
25- from typing import Any , Dict , List , NamedTuple , Optional # NOQA for mypy types
25+ from types import TracebackType
26+ from typing import Any , Dict , Iterator , List , NamedTuple , Optional , Sequence , Tuple , Type , Union
2627
2728import trino .client
2829import trino .exceptions
7273logger = trino .logging .get_logger (__name__ )
7374
7475
75- def connect (* args , ** kwargs ) :
76+ def connect (* args : Any , ** kwargs : Any ) -> trino . dbapi . Connection :
7677 """Constructor for creating a connection to the database.
7778
7879 See class :py:class:`Connection` for arguments.
@@ -92,28 +93,28 @@ class Connection(object):
9293
9394 def __init__ (
9495 self ,
95- host ,
96- port = constants .DEFAULT_PORT ,
97- user = None ,
98- source = constants .DEFAULT_SOURCE ,
99- catalog = constants .DEFAULT_CATALOG ,
100- schema = constants .DEFAULT_SCHEMA ,
101- session_properties = None ,
102- http_headers = None ,
103- http_scheme = constants .HTTP ,
104- auth = constants .DEFAULT_AUTH ,
105- extra_credential = None ,
106- redirect_handler = None ,
107- max_attempts = constants .DEFAULT_MAX_ATTEMPTS ,
108- request_timeout = constants .DEFAULT_REQUEST_TIMEOUT ,
109- isolation_level = IsolationLevel .AUTOCOMMIT ,
110- verify = True ,
111- http_session = None ,
112- client_tags = None ,
113- legacy_primitive_types = False ,
114- roles = None ,
96+ host : str ,
97+ port : int = constants .DEFAULT_PORT ,
98+ user : Optional [ str ] = None ,
99+ source : str = constants .DEFAULT_SOURCE ,
100+ catalog : Optional [ str ] = constants .DEFAULT_CATALOG ,
101+ schema : Optional [ str ] = constants .DEFAULT_SCHEMA ,
102+ session_properties : Optional [ Dict [ str , str ]] = None ,
103+ http_headers : Optional [ Dict [ str , str ]] = None ,
104+ http_scheme : str = constants .HTTP ,
105+ auth : Optional [ trino . auth . Authentication ] = constants .DEFAULT_AUTH ,
106+ extra_credential : Optional [ List [ Tuple [ str , str ]]] = None ,
107+ redirect_handler : Optional [ str ] = None ,
108+ max_attempts : int = constants .DEFAULT_MAX_ATTEMPTS ,
109+ request_timeout : float = constants .DEFAULT_REQUEST_TIMEOUT ,
110+ isolation_level : IsolationLevel = IsolationLevel .AUTOCOMMIT ,
111+ verify : Union [ bool | str ] = True ,
112+ http_session : Optional [ trino . client . TrinoRequest . http . Session ] = None ,
113+ client_tags : Optional [ List [ str ]] = None ,
114+ legacy_primitive_types : Optional [ bool ] = False ,
115+ roles : Optional [ Dict [ str , str ]] = None ,
115116 timezone = None ,
116- ):
117+ ) -> None :
117118 self .host = host
118119 self .port = port
119120 self .user = user
@@ -151,50 +152,53 @@ def __init__(
151152
152153 self ._isolation_level = isolation_level
153154 self ._request = None
154- self ._transaction = None
155+ self ._transaction : Optional [ Transaction ] = None
155156 self .legacy_primitive_types = legacy_primitive_types
156157
157158 @property
158- def isolation_level (self ):
159+ def isolation_level (self ) -> IsolationLevel :
159160 return self ._isolation_level
160161
161162 @property
162- def transaction (self ):
163+ def transaction (self ) -> Optional [ Transaction ] :
163164 return self ._transaction
164165
165- def __enter__ (self ):
166+ def __enter__ (self ) -> object :
166167 return self
167168
168- def __exit__ (self , exc_type , exc_value , traceback ):
169+ def __exit__ (self ,
170+ exc_type : Optional [Type [BaseException ]],
171+ exc_value : Optional [BaseException ],
172+ traceback : Optional [TracebackType ]) -> None :
169173 try :
170174 self .commit ()
171175 except Exception :
172176 self .rollback ()
173177 else :
174178 self .close ()
175179
176- def close (self ):
180+ def close (self ) -> None :
177181 # TODO cancel outstanding queries?
178182 self ._http_session .close ()
179183
180- def start_transaction (self ):
184+ def start_transaction (self ) -> Transaction :
181185 self ._transaction = Transaction (self ._create_request ())
182186 self ._transaction .begin ()
183187 return self ._transaction
184188
185- def commit (self ):
186- if self .transaction is None :
189+ def commit (self ) -> None :
190+ if self ._transaction is None :
187191 return
188192 self ._transaction .commit ()
189193 self ._transaction = None
190194
191- def rollback (self ):
192- if self .transaction is None :
195+ def rollback (self ) -> None :
196+ if self ._transaction is None :
193197 raise RuntimeError ("no transaction was started" )
194198 self ._transaction .rollback ()
195199 self ._transaction = None
196200
197- def _create_request (self ):
201+ def _create_request (self ) -> trino . client . TrinoRequest :
198202 return trino .client .TrinoRequest (
199203 self .host ,
200204 self .port ,
@@ -207,7 +211,7 @@ def _create_request(self):
207211 self .request_timeout ,
208212 )
209213
210- def cursor (self , legacy_primitive_types : bool = None ):
214+ def cursor (self , legacy_primitive_types : bool = None ) -> 'trino.dbapi.Cursor' :
211215 """Return a new :py:class:`Cursor` object using the connection."""
212216 if self .isolation_level != IsolationLevel .AUTOCOMMIT :
213217 if self .transaction is None :
@@ -271,7 +275,10 @@ class Cursor(object):
271275
272276 """
273277
274- def __init__ (self , connection , request , legacy_primitive_types : bool = False ):
278+ def __init__ (self ,
279+ connection : Connection ,
280+ request : trino .client .TrinoRequest ,
281+ legacy_primitive_types : bool = False ) -> None :
275282 if not isinstance (connection , Connection ):
276283 raise ValueError (
277284 "connection must be a Connection object: {}" .format (type (connection ))
@@ -280,32 +287,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280287 self ._request = request
281288
282289 self .arraysize = 1
283- self ._iterator = None
284- self ._query = None
290+ self ._iterator : Optional [ Iterator [ Any ]] = None
291+ self ._query : Optional [ trino . client . TrinoQuery ] = None
285292 self ._legacy_primitive_types = legacy_primitive_types
286293
287- def __iter__ (self ):
294+ def __iter__ (self ) -> Optional [ Iterator [ Any ]] :
288295 return self ._iterator
289296
290297 @property
291- def connection (self ):
298+ def connection (self ) -> Connection :
292299 return self ._connection
293300
294301 @property
295- def info_uri (self ):
302+ def info_uri (self ) -> Optional [ str ] :
296303 if self ._query is not None :
297304 return self ._query .info_uri
298305 return None
299306
300307 @property
301- def update_type (self ):
308+ def update_type (self ) -> Optional [ str ] :
302309 if self ._query is not None :
303310 return self ._query .update_type
304311 return None
305312
306313 @property
307- def description (self ) -> List [ColumnDescription ]:
308- if self ._query .columns is None :
314+ def description (self ) -> Optional [ List [Tuple [ Any , ...]] ]:
315+ if self ._query is None or self . _query .columns is None :
309316 return None
310317
311318 # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +321,7 @@ def description(self) -> List[ColumnDescription]:
314321 ]
315322
316323 @property
317- def rowcount (self ):
324+ def rowcount (self ) -> int :
318325 """Not supported.
319326
320327 Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +332,21 @@ def rowcount(self):
325332 return - 1
326333
327334 @property
328- def stats (self ):
335+ def stats (self ) -> Optional [ Dict [ Any , Any ]] :
329336 if self ._query is not None :
330337 return self ._query .stats
331338 return None
332339
333340 @property
334- def query_id (self ) -> Optional [str ]:
335- if self ._query is not None :
336- return self ._query .query_id
337- return None
338-
339- @property
340- def warnings (self ):
341+ def warnings (self ) -> Optional [List [Dict [Any , Any ]]]:
341342 if self ._query is not None :
342343 return self ._query .warnings
343344 return None
344345
345- def setinputsizes (self , sizes ) :
346+ def setinputsizes (self , sizes : Sequence [ Any ]) -> None :
346347 raise trino .exceptions .NotSupportedError
347348
348- def setoutputsize (self , size , column ) :
349+ def setoutputsize (self , size : int , column : Optional [ int ]) -> None :
349350 raise trino .exceptions .NotSupportedError
350351
351352 def _prepare_statement (self , statement : str , name : str ) -> None :
@@ -363,13 +364,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363364
364365 def _execute_prepared_statement (
365366 self ,
366- statement_name ,
367- params
368- ):
367+ statement_name : str ,
368+ params : Any
369+ ) -> trino . client . TrinoQuery :
369370 sql = 'EXECUTE ' + statement_name + ' USING ' + ',' .join (map (self ._format_prepared_param , params ))
370371 return trino .client .TrinoQuery (self ._request , sql = sql , legacy_primitive_types = self ._legacy_primitive_types )
371372
372- def _format_prepared_param (self , param ) :
373+ def _format_prepared_param (self , param : Any ) -> str :
373374 """
374375 Formats parameters to be passed in an
375376 EXECUTE statement.
@@ -451,10 +452,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451452 legacy_primitive_types = self ._legacy_primitive_types )
452453 query .execute ()
453454
454- def _generate_unique_statement_name (self ):
455+ def _generate_unique_statement_name (self ) -> str :
455456 return 'st_' + uuid .uuid4 ().hex .replace ('-' , '' )
456457
457- def execute (self , operation , params = None ):
458+ def execute (self , operation : str , params : Optional [ Any ] = None ) -> trino . client . TrinoResult :
458459 if params :
459460 assert isinstance (params , (list , tuple )), (
460461 'params must be a list or tuple containing the query '
@@ -484,7 +485,7 @@ def execute(self, operation, params=None):
484485 self ._iterator = iter (self ._query .execute ())
485486 return self
486487
487- def executemany (self , operation , seq_of_params ) :
488+ def executemany (self , operation : str , seq_of_params : Any ) -> None :
488489 """
489490 PEP-0249: Prepare a database operation (query or command) and then
490491 execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -503,6 +504,7 @@ def executemany(self, operation, seq_of_params):
503504 for parameters in seq_of_params [:- 1 ]:
504505 self .execute (operation , parameters )
505506 self .fetchall ()
507+ assert self ._query is not None
506508 if self ._query .update_type is None :
507509 raise NotSupportedError ("Query must return update type" )
508510 if seq_of_params :
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529531 except trino .exceptions .HttpError as err :
530532 raise trino .exceptions .OperationalError (str (err ))
531533
532- def fetchmany (self , size = None ) -> List [List [Any ]]:
534+ def fetchmany (self , size : Optional [ int ] = None ) -> List [List [Any ]]:
533535 """
534536 PEP-0249: Fetch the next set of rows of a query result, returning a
535537 sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -562,6 +564,7 @@ def fetchmany(self, size=None) -> List[List[Any]]:
562564
563565 return result
564566
567+ < << << << HEAD
565568 def describe (self , sql : str ) -> List [DescribeOutput ]:
566569 """
567570 List the output columns of a SQL statement, including the column name (or alias), catalog, schema, table, type,
@@ -584,66 +587,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584587
585588 return list (map (lambda x : DescribeOutput .from_row (x ), result ))
586589
587- def genall (self ):
590+ def genall (self ) -> trino . client . TrinoResult :
588591 return self ._query .result
589592
590593 def fetchall (self ) -> List [List [Any ]]:
591594 return list (self .genall ())
592595
593- def cancel (self ):
596+ def cancel (self ) -> None :
594597 if self ._query is None :
595598 raise trino .exceptions .OperationalError (
596599 "Cancel query failed; no running query"
597600 )
598601 self ._query .cancel ()
599602
600- def close (self ):
603+ def close (self ) -> None :
601604 self .cancel ()
602605 # TODO: Cancel not only the last query executed on this cursor
603606 # but also any other outstanding queries executed through this cursor.
604-
605-
606- Date = datetime .date
607- Time = datetime .time
608- Timestamp = datetime .datetime
609- DateFromTicks = datetime .date .fromtimestamp
610- TimestampFromTicks = datetime .datetime .fromtimestamp
611-
612-
613- def TimeFromTicks (ticks ):
614- return datetime .time (* datetime .localtime (ticks )[3 :6 ])
615-
616-
617- def Binary (string ):
618- return string .encode ("utf-8" )
619-
620-
621- class DBAPITypeObject :
622- def __init__ (self , * values ):
623- self .values = [v .lower () for v in values ]
624-
625- def __eq__ (self , other ):
626- return other .lower () in self .values
627-
628-
629- STRING = DBAPITypeObject ("VARCHAR" , "CHAR" , "VARBINARY" , "JSON" , "IPADDRESS" )
630-
631- BINARY = DBAPITypeObject (
632- "ARRAY" , "MAP" , "ROW" , "HyperLogLog" , "P4HyperLogLog" , "QDigest"
633- )
634-
635- NUMBER = DBAPITypeObject (
636- "BOOLEAN" , "TINYINT" , "SMALLINT" , "INTEGER" , "BIGINT" , "REAL" , "DOUBLE" , "DECIMAL"
637- )
638-
639- DATETIME = DBAPITypeObject (
640- "DATE" ,
641- "TIME" ,
642- "TIME WITH TIME ZONE" ,
643- "TIMESTAMP" ,
644- "TIMESTAMP WITH TIME ZONE" ,
645- "INTERVAL YEAR TO MONTH" ,
646- "INTERVAL DAY TO SECOND" ,
647- )
648-
649- ROWID = DBAPITypeObject () # nothing indicates row id in Trino
0 commit comments