4747from datetime import date , datetime , time , timedelta , timezone , tzinfo
4848from decimal import Decimal
4949from time import sleep
50- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
50+ from typing import (
51+ Any ,
52+ Callable ,
53+ Dict ,
54+ Generator ,
55+ Generic ,
56+ List ,
57+ Optional ,
58+ Tuple ,
59+ Type ,
60+ TypeVar ,
61+ Union ,
62+ )
5163
5264import pytz
5365import requests
5466from pytz .tzinfo import BaseTzInfo
55- from tzlocal import get_localzone_name # type: ignore
67+ from tzlocal import get_localzone_name
5668
5769import trino .logging
5870from trino import constants , exceptions
5971
6072try :
61- from zoneinfo import ZoneInfo # type: ignore
73+ from zoneinfo import ZoneInfo
6274
6375except ModuleNotFoundError :
6476 from backports .zoneinfo import ZoneInfo # type: ignore
7587else :
7688 PROXIES = {}
7789
78- _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re .compile (r' ^\S[^\s=]*$' )
90+ _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re .compile (r" ^\S[^\s=]*$" )
7991
8092T = TypeVar ("T" )
8193
@@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]:
461473 "{}={}" .format (catalog , urllib .parse .quote (str (role )))
462474 for catalog , role in self ._client_session .roles .items ()
463475 )
464- if self ._client_session .client_tags is not None and len (self ._client_session .client_tags ) > 0 :
465- headers [constants .HEADER_CLIENT_TAGS ] = "," .join (self ._client_session .client_tags )
476+ if (
477+ self ._client_session .client_tags is not None
478+ and len (self ._client_session .client_tags ) > 0
479+ ):
480+ headers [constants .HEADER_CLIENT_TAGS ] = "," .join (
481+ self ._client_session .client_tags
482+ )
466483
467484 headers [constants .HEADER_SESSION ] = "," .join (
468485 # ``name`` must not contain ``=``
@@ -486,18 +503,23 @@ def http_headers(self) -> Dict[str, str]:
486503 transaction_id = self ._client_session .transaction_id
487504 headers [constants .HEADER_TRANSACTION ] = transaction_id
488505
489- if self ._client_session .extra_credential is not None and \
490- len (self ._client_session .extra_credential ) > 0 :
506+ if (
507+ self ._client_session .extra_credential is not None
508+ and len (self ._client_session .extra_credential ) > 0
509+ ):
491510
492511 for tup in self ._client_session .extra_credential :
493512 self ._verify_extra_credential (tup )
494513
495514 # HTTP 1.1 section 4.2 combine multiple extra credentials into a
496515 # comma-separated value
497516 # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
498- headers [constants .HEADER_EXTRA_CREDENTIAL ] = \
499- ", " .join (
500- [f"{ tup [0 ]} ={ urllib .parse .quote_plus (tup [1 ])} " for tup in self ._client_session .extra_credential ])
517+ headers [constants .HEADER_EXTRA_CREDENTIAL ] = ", " .join (
518+ [
519+ f"{ tup [0 ]} ={ urllib .parse .quote_plus (tup [1 ])} "
520+ for tup in self ._client_session .extra_credential
521+ ]
522+ )
501523
502524 return headers
503525
@@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
562584 while http_response is not None and http_response .is_redirect :
563585 location = http_response .headers ["Location" ]
564586 url = self ._redirect_handler .handle (location )
565- logger .info ("redirect %s from %s to %s" , http_response .status_code , location , url )
587+ logger .info (
588+ "redirect %s from %s to %s" ,
589+ http_response .status_code ,
590+ location ,
591+ url ,
592+ )
566593 http_response = self ._post (
567594 url ,
568595 data = data ,
@@ -606,7 +633,7 @@ def raise_response_error(self, http_response):
606633 raise exceptions .HttpError (
607634 "error {}{}" .format (
608635 http_response .status_code ,
609- ": {}" .format (http_response .content ) if http_response .content else "" ,
636+ ": {}" .format (repr ( http_response .content ) ) if http_response .content else "" ,
610637 )
611638 )
612639
@@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus:
633660 self ._client_session .properties [key ] = value
634661
635662 if constants .HEADER_SET_CATALOG in http_response .headers :
636- self ._client_session .catalog = http_response .headers [constants .HEADER_SET_CATALOG ]
663+ self ._client_session .catalog = http_response .headers [
664+ constants .HEADER_SET_CATALOG
665+ ]
637666
638667 if constants .HEADER_SET_SCHEMA in http_response .headers :
639- self ._client_session .schema = http_response .headers [constants .HEADER_SET_SCHEMA ]
668+ self ._client_session .schema = http_response .headers [
669+ constants .HEADER_SET_SCHEMA
670+ ]
640671
641672 if constants .HEADER_SET_ROLE in http_response .headers :
642673 for key , value in get_roles_values (
643- http_response .headers , constants .HEADER_SET_ROLE
674+ http_response .headers , constants .HEADER_SET_ROLE
644675 ):
645676 self ._client_session .roles [key ] = value
646677
@@ -676,12 +707,16 @@ def _verify_extra_credential(self, header):
676707 key = header [0 ]
677708
678709 if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX .match (key ):
679- raise ValueError (f"whitespace or '=' are disallowed in extra credential '{ key } '" )
710+ raise ValueError (
711+ f"whitespace or '=' are disallowed in extra credential '{ key } '"
712+ )
680713
681714 try :
682- key .encode ().decode (' ascii' )
715+ key .encode ().decode (" ascii" )
683716 except UnicodeDecodeError :
684- raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
717+ raise ValueError (
718+ f"only ASCII characters are allowed in extra credential '{ key } '"
719+ )
685720
686721
687722class TrinoResult (object ):
@@ -847,7 +882,10 @@ def cancel(self) -> None:
847882
848883 def is_finished (self ) -> bool :
849884 import warnings
850- warnings .warn ("is_finished is deprecated, use finished instead" , DeprecationWarning )
885+
886+ warnings .warn (
887+ "is_finished is deprecated, use finished instead" , DeprecationWarning
888+ )
851889 return self .finished
852890
853891 @property
@@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]):
910948 def map (self , value ) -> Optional [float ]:
911949 if value is None :
912950 return None
913- if value == ' Infinity' :
951+ if value == " Infinity" :
914952 return float ("inf" )
915- if value == ' -Infinity' :
953+ if value == " -Infinity" :
916954 return float ("-inf" )
917- if value == ' NaN' :
955+ if value == " NaN" :
918956 return float ("nan" )
919957 return float (value )
920958
@@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]):
11191157 def map (self , values : List [Any ]) -> Optional [Tuple [Optional [Any ], ...]]:
11201158 if values is None :
11211159 return None
1122- return tuple (self .mappers [index ].map (value ) for index , value in enumerate (values ))
1160+ return tuple (
1161+ self .mappers [index ].map (value ) for index , value in enumerate (values )
1162+ )
11231163
11241164
11251165class MapValueMapper (ValueMapper [Dict [Any , Optional [Any ]]]):
@@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
11311171 if values is None :
11321172 return None
11331173 return {
1134- self .key_mapper .map (key ): self .value_mapper .map (value ) for key , value in values .items ()
1174+ self .key_mapper .map (key ): self .value_mapper .map (value )
1175+ for key , value in values .items ()
11351176 }
11361177
11371178
@@ -1151,6 +1192,7 @@ class RowMapperFactory:
11511192 lambda functions (one for each column) which will process a data value
11521193 and returns a RowMapper instance which will process rows of data
11531194 """
1195+
11541196 NO_OP_ROW_MAPPER = NoOpRowMapper ()
11551197
11561198 def create (self , columns , legacy_primitive_types ):
@@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types):
11631205 def _create_value_mapper (self , column ) -> ValueMapper :
11641206 col_type = column ['rawType' ]
11651207
1166- if col_type == ' array' :
1167- value_mapper = self ._create_value_mapper (column [' arguments' ][0 ][' value' ])
1208+ if col_type == " array" :
1209+ value_mapper = self ._create_value_mapper (column [" arguments" ][0 ][" value" ])
11681210 return ArrayValueMapper (value_mapper )
1169- elif col_type == 'row' :
1170- mappers = [self ._create_value_mapper (arg ['value' ]['typeSignature' ]) for arg in column ['arguments' ]]
1211+ elif col_type == "row" :
1212+ mappers = [
1213+ self ._create_value_mapper (arg ["value" ]["typeSignature" ])
1214+ for arg in column ["arguments" ]
1215+ ]
11711216 return RowValueMapper (mappers )
1172- elif col_type == ' map' :
1173- key_mapper = self ._create_value_mapper (column [' arguments' ][0 ][' value' ])
1174- value_mapper = self ._create_value_mapper (column [' arguments' ][1 ][' value' ])
1217+ elif col_type == " map" :
1218+ key_mapper = self ._create_value_mapper (column [" arguments" ][0 ][" value" ])
1219+ value_mapper = self ._create_value_mapper (column [" arguments" ][1 ][" value" ])
11751220 return MapValueMapper (key_mapper , value_mapper )
1176- elif col_type .startswith (' decimal' ):
1221+ elif col_type .startswith (" decimal" ):
11771222 return DecimalValueMapper ()
1178- elif col_type .startswith (' double' ) or col_type .startswith (' real' ):
1223+ elif col_type .startswith (" double" ) or col_type .startswith (" real" ):
11791224 return DoubleValueMapper ()
11801225 elif col_type .startswith ('timestamp' ) and 'with time zone' in col_type :
11811226 return TimestampWithTimeZoneValueMapper (self ._get_precision (column ))
0 commit comments