2020import binascii
2121import datetime
2222import math
23+ import time
2324import uuid
2425from decimal import Decimal
25- from typing import Any , Dict , List , NamedTuple , Optional # NOQA for mypy types
26+ from types import TracebackType
27+ from typing import Any , Dict , Iterator , List , NamedTuple , Optional , Sequence , Tuple , Type , Union
2628
2729import trino .client
2830import trino .exceptions
7274logger = trino .logging .get_logger (__name__ )
7375
7476
75- def connect (* args , ** kwargs ) :
77+ def connect (* args : Any , ** kwargs : Any ) -> trino . dbapi . Connection :
7678 """Constructor for creating a connection to the database.
7779
7880 See class :py:class:`Connection` for arguments.
@@ -92,28 +94,28 @@ class Connection(object):
9294
9395 def __init__ (
9496 self ,
95- host ,
96- port = constants .DEFAULT_PORT ,
97- user = None ,
98- source = constants .DEFAULT_SOURCE ,
99- catalog = constants .DEFAULT_CATALOG ,
100- schema = constants .DEFAULT_SCHEMA ,
101- session_properties = None ,
102- http_headers = None ,
103- http_scheme = constants .HTTP ,
104- auth = constants .DEFAULT_AUTH ,
105- extra_credential = None ,
106- redirect_handler = None ,
107- max_attempts = constants .DEFAULT_MAX_ATTEMPTS ,
108- request_timeout = constants .DEFAULT_REQUEST_TIMEOUT ,
109- isolation_level = IsolationLevel .AUTOCOMMIT ,
110- verify = True ,
111- http_session = None ,
112- client_tags = None ,
113- legacy_primitive_types = False ,
114- roles = None ,
97+ host : str ,
98+ port : int = constants .DEFAULT_PORT ,
99+ user : Optional [ str ] = None ,
100+ source : str = constants .DEFAULT_SOURCE ,
101+ catalog : Optional [ str ] = constants .DEFAULT_CATALOG ,
102+ schema : Optional [ str ] = constants .DEFAULT_SCHEMA ,
103+ session_properties : Optional [ Dict [ str , str ]] = None ,
104+ http_headers : Optional [ Dict [ str , str ]] = None ,
105+ http_scheme : str = constants .HTTP ,
106+ auth : Optional [ trino . auth . Authentication ] = constants .DEFAULT_AUTH ,
107+ extra_credential : Optional [ List [ Tuple [ str , str ]]] = None ,
108+ redirect_handler : Optional [ str ] = None ,
109+ max_attempts : int = constants .DEFAULT_MAX_ATTEMPTS ,
110+ request_timeout : float = constants .DEFAULT_REQUEST_TIMEOUT ,
111+ isolation_level : IsolationLevel = IsolationLevel .AUTOCOMMIT ,
112+ verify : Union [ bool | str ] = True ,
113+ http_session : Optional [ trino . client . TrinoRequest . http . Session ] = None ,
114+ client_tags : Optional [ List [ str ]] = None ,
115+ legacy_primitive_types : Optional [ bool ] = False ,
116+ roles : Optional [ Dict [ str , str ]] = None ,
115117 timezone = None ,
116- ):
118+ ) -> None :
117119 self .host = host
118120 self .port = port
119121 self .user = user
@@ -151,50 +153,53 @@ def __init__(
151153
152154 self ._isolation_level = isolation_level
153155 self ._request = None
154- self ._transaction = None
156+ self ._transaction : Optional [ Transaction ] = None
155157 self .legacy_primitive_types = legacy_primitive_types
156158
157159 @property
158- def isolation_level (self ):
160+ def isolation_level (self ) -> IsolationLevel :
159161 return self ._isolation_level
160162
161163 @property
162- def transaction (self ):
164+ def transaction (self ) -> Optional [ Transaction ] :
163165 return self ._transaction
164166
165- def __enter__ (self ):
167+ def __enter__ (self ) -> object :
166168 return self
167169
168- def __exit__ (self , exc_type , exc_value , traceback ):
170+ def __exit__ (self ,
171+ exc_type : Optional [Type [BaseException ]],
172+ exc_value : Optional [BaseException ],
173+ traceback : Optional [TracebackType ]) -> None :
169174 try :
170175 self .commit ()
171176 except Exception :
172177 self .rollback ()
173178 else :
174179 self .close ()
175180
176- def close (self ):
181+ def close (self ) -> None :
177182 # TODO cancel outstanding queries?
178183 self ._http_session .close ()
179184
180- def start_transaction (self ):
185+ def start_transaction (self ) -> Transaction :
181186 self ._transaction = Transaction (self ._create_request ())
182187 self ._transaction .begin ()
183188 return self ._transaction
184189
185- def commit (self ):
190+ def commit (self ) -> None :
186191 if self .transaction is None :
187192 return
188- self ._transaction .commit ()
193+ self .transaction .commit ()
189194 self ._transaction = None
190195
191- def rollback (self ):
196+ def rollback (self ) -> None :
192197 if self .transaction is None :
193198 raise RuntimeError ("no transaction was started" )
194- self ._transaction .rollback ()
199+ self .transaction .rollback ()
195200 self ._transaction = None
196201
197- def _create_request (self ):
202+ def _create_request (self ) -> trino . client . TrinoRequest :
198203 return trino .client .TrinoRequest (
199204 self .host ,
200205 self .port ,
@@ -207,7 +212,7 @@ def _create_request(self):
207212 self .request_timeout ,
208213 )
209214
210- def cursor (self , legacy_primitive_types : bool = None ):
215+ def cursor (self , legacy_primitive_types : bool = None ) -> 'trino.dbapi.Cursor' :
211216 """Return a new :py:class:`Cursor` object using the connection."""
212217 if self .isolation_level != IsolationLevel .AUTOCOMMIT :
213218 if self .transaction is None :
@@ -271,7 +276,10 @@ class Cursor(object):
271276
272277 """
273278
274- def __init__ (self , connection , request , legacy_primitive_types : bool = False ):
279+ def __init__ (self ,
280+ connection : Connection ,
281+ request : trino .client .TrinoRequest ,
282+ legacy_primitive_types : bool = False ) -> None :
275283 if not isinstance (connection , Connection ):
276284 raise ValueError (
277285 "connection must be a Connection object: {}" .format (type (connection ))
@@ -280,32 +288,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280288 self ._request = request
281289
282290 self .arraysize = 1
283- self ._iterator = None
284- self ._query = None
291+ self ._iterator : Optional [ Iterator [ List [ Any ]]] = None
292+ self ._query : Optional [ trino . client . TrinoQuery ] = None
285293 self ._legacy_primitive_types = legacy_primitive_types
286294
287- def __iter__ (self ):
295+ def __iter__ (self ) -> Optional [ Iterator [ List [ Any ]]] :
288296 return self ._iterator
289297
290298 @property
291- def connection (self ):
299+ def connection (self ) -> Connection :
292300 return self ._connection
293301
294302 @property
295- def info_uri (self ):
303+ def info_uri (self ) -> Optional [ str ] :
296304 if self ._query is not None :
297305 return self ._query .info_uri
298306 return None
299307
300308 @property
301- def update_type (self ):
309+ def update_type (self ) -> Optional [ str ] :
302310 if self ._query is not None :
303311 return self ._query .update_type
304312 return None
305313
306314 @property
307- def description (self ) -> List [ColumnDescription ]:
308- if self ._query .columns is None :
315+ def description (self ) -> Optional [ List [Tuple [ Any , ...]] ]:
316+ if self ._query is None or self . _query .columns is None :
309317 return None
310318
311319 # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +322,7 @@ def description(self) -> List[ColumnDescription]:
314322 ]
315323
316324 @property
317- def rowcount (self ):
325+ def rowcount (self ) -> int :
318326 """Not supported.
319327
320328 Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +333,21 @@ def rowcount(self):
325333 return - 1
326334
327335 @property
328- def stats (self ):
336+ def stats (self ) -> Optional [ Dict [ Any , Any ]] :
329337 if self ._query is not None :
330338 return self ._query .stats
331339 return None
332340
333341 @property
334- def query_id (self ) -> Optional [str ]:
335- if self ._query is not None :
336- return self ._query .query_id
337- return None
338-
339- @property
340- def warnings (self ):
342+ def warnings (self ) -> Optional [List [Dict [Any , Any ]]]:
341343 if self ._query is not None :
342344 return self ._query .warnings
343345 return None
344346
345- def setinputsizes (self , sizes ) :
347+ def setinputsizes (self , sizes : Sequence [ Any ]) -> None :
346348 raise trino .exceptions .NotSupportedError
347349
348- def setoutputsize (self , size , column ) :
350+ def setoutputsize (self , size : int , column : Optional [ int ]) -> None :
349351 raise trino .exceptions .NotSupportedError
350352
351353 def _prepare_statement (self , statement : str , name : str ) -> None :
@@ -363,13 +365,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363365
364366 def _execute_prepared_statement (
365367 self ,
366- statement_name ,
367- params
368- ):
368+ statement_name : str ,
369+ params : Any
370+ ) -> trino . client . TrinoQuery :
369371 sql = 'EXECUTE ' + statement_name + ' USING ' + ',' .join (map (self ._format_prepared_param , params ))
370372 return trino .client .TrinoQuery (self ._request , sql = sql , legacy_primitive_types = self ._legacy_primitive_types )
371373
372- def _format_prepared_param (self , param ) :
374+ def _format_prepared_param (self , param : Any ) -> str :
373375 """
374376 Formats parameters to be passed in an
375377 EXECUTE statement.
@@ -451,10 +453,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451453 legacy_primitive_types = self ._legacy_primitive_types )
452454 query .execute ()
453455
454- def _generate_unique_statement_name (self ):
456+ def _generate_unique_statement_name (self ) -> str :
455457 return 'st_' + uuid .uuid4 ().hex .replace ('-' , '' )
456458
457- def execute (self , operation , params = None ):
459+ def execute (self , operation : str , params : Optional [ Any ] = None ) -> trino . client . TrinoResult :
458460 if params :
459461 assert isinstance (params , (list , tuple )), (
460462 'params must be a list or tuple containing the query '
@@ -484,7 +486,7 @@ def execute(self, operation, params=None):
484486 self ._iterator = iter (self ._query .execute ())
485487 return self
486488
487- def executemany (self , operation , seq_of_params ) :
489+ def executemany (self , operation : str , seq_of_params : Any ) -> None :
488490 """
489491 PEP-0249: Prepare a database operation (query or command) and then
490492 execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529531 except trino .exceptions .HttpError as err :
530532 raise trino .exceptions .OperationalError (str (err ))
531533
532- def fetchmany (self , size = None ) -> List [List [Any ]]:
534+ def fetchmany (self , size : Optional [ int ] = None ) -> List [List [Any ]]:
533535 """
534536 PEP-0249: Fetch the next set of rows of a query result, returning a
535537 sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -584,20 +586,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584586
585587 return list (map (lambda x : DescribeOutput .from_row (x ), result ))
586588
587- def genall (self ):
589+ def genall (self ) -> trino . client . TrinoResult :
588590 return self ._query .result
589591
590592 def fetchall (self ) -> List [List [Any ]]:
591593 return list (self .genall ())
592594
593- def cancel (self ):
595+ def cancel (self ) -> None :
594596 if self ._query is None :
595597 raise trino .exceptions .OperationalError (
596598 "Cancel query failed; no running query"
597599 )
598600 self ._query .cancel ()
599601
600- def close (self ):
602+ def close (self ) -> None :
601603 self .cancel ()
602604 # TODO: Cancel not only the last query executed on this cursor
603605 # but also any other outstanding queries executed through this cursor.
@@ -610,19 +612,19 @@ def close(self):
610612TimestampFromTicks = datetime .datetime .fromtimestamp
611613
612614
613- def TimeFromTicks (ticks ) :
614- return datetime .time (* datetime .localtime (ticks )[3 :6 ])
615+ def TimeFromTicks (ticks : int ) -> datetime . time :
616+ return datetime .time (* time .localtime (ticks )[3 :6 ])
615617
616618
617- def Binary (string ) :
619+ def Binary (string : str ) -> bytes :
618620 return string .encode ("utf-8" )
619621
620622
621623class DBAPITypeObject :
622- def __init__ (self , * values ):
624+ def __init__ (self , * values : str ):
623625 self .values = [v .lower () for v in values ]
624626
625- def __eq__ (self , other ) :
627+ def __eq__ (self , other : object ) -> bool :
626628 return other .lower () in self .values
627629
628630
0 commit comments