1- import operator
2- import csv
3- import six
41import codecs
2+ import csv
3+ import operator
54import os .path
65import re
6+
7+ import prettytable
8+ import six
79import sqlalchemy
810import sqlparse
9- import prettytable
11+
12+ from .column_guesser import ColumnGuesserMixin
13+
1014try :
1115 from pgspecial .main import PGSpecial
1216except ImportError :
1317 PGSpecial = None
14- from .column_guesser import ColumnGuesserMixin
1518
1619
1720def unduplicate_field_names (field_names ):
@@ -26,6 +29,7 @@ def unduplicate_field_names(field_names):
2629 res .append (k )
2730 return res
2831
32+
2933class UnicodeWriter (object ):
3034 """
3135 A CSV writer which will write rows to CSV file "f",
@@ -41,19 +45,17 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
4145
4246 def writerow (self , row ):
4347 if six .PY2 :
44- _row = [s .encode ("utf-8" )
45- if hasattr (s , "encode" )
46- else s
48+ _row = [s .encode ("utf-8" ) if hasattr (s , "encode" ) else s
4749 for s in row ]
4850 else :
4951 _row = row
5052 self .writer .writerow (_row )
5153 # Fetch UTF-8 output from the queue ...
5254 data = self .queue .getvalue ()
5355 if six .PY2 :
54- data = data .decode ("utf-8" )
55- # ... and reencode it into the target encoding
56- data = self .encoder .encode (data )
56+ data = data .decode ("utf-8" )
57+ # ... and reencode it into the target encoding
58+ data = self .encoder .encode (data )
5759 # write to the target stream
5860 self .stream .write (data )
5961 # empty queue
@@ -64,14 +66,20 @@ def writerows(self, rows):
6466 for row in rows :
6567 self .writerow (row )
6668
69+
6770class CsvResultDescriptor (object ):
6871 """Provides IPython Notebook-friendly output for the feedback after a ``.csv`` called."""
72+
6973 def __init__ (self , file_path ):
7074 self .file_path = file_path
75+
7176 def __repr__ (self ):
72- return 'CSV results at %s' % os .path .join (os .path .abspath ('.' ), self .file_path )
77+ return 'CSV results at %s' % os .path .join (
78+ os .path .abspath ('.' ), self .file_path )
79+
7380 def _repr_html_ (self ):
74- return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' , self .file_path )
81+ return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' ,
82+ self .file_path )
7583
7684
7785def _nonbreaking_spaces (match_obj ):
@@ -84,6 +92,7 @@ def _nonbreaking_spaces(match_obj):
8492 spaces = ' ' * len (match_obj .group (2 ))
8593 return '%s%s' % (match_obj .group (1 ), spaces )
8694
95+
8796_cell_with_spaces_pattern = re .compile (r'(<td>)( {2,})' )
8897
8998
@@ -93,6 +102,7 @@ class ResultSet(list, ColumnGuesserMixin):
93102
94103 Can access rows listwise, or by string value of leftmost column.
95104 """
105+
96106 def __init__ (self , sqlaproxy , sql , config ):
97107 self .keys = sqlaproxy .keys ()
98108 self .sql = sql
@@ -118,7 +128,8 @@ def _repr_html_(self):
118128 self .pretty .add_rows (self )
119129 result = self .pretty .get_html_string ()
120130 result = _cell_with_spaces_pattern .sub (_nonbreaking_spaces , result )
121- if self .config .displaylimit and len (self ) > self .config .displaylimit :
131+ if self .config .displaylimit and len (
132+ self ) > self .config .displaylimit :
122133 result = '%s\n <span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
123134 result , len (self ), self .config .displaylimit )
124135 return result
@@ -143,6 +154,7 @@ def __getitem__(self, key):
143154 if len (result ) > 1 :
144155 raise KeyError ('%d results for "%s"' % (len (result ), key ))
145156 return result [0 ]
157+
146158 def dict (self ):
147159 """Returns a single dict built from the result set
148160
@@ -217,7 +229,7 @@ def plot(self, title=None, **kwargs):
217229 plt .ylabel (ylabel )
218230 return plot
219231
220- def bar (self , key_word_sep = " " , title = None , ** kwargs ):
232+ def bar (self , key_word_sep = " " , title = None , ** kwargs ):
221233 """Generates a pylab bar plot from the result set.
222234
223235 ``matplotlib`` must be installed, and in an
@@ -241,8 +253,7 @@ def bar(self, key_word_sep = " ", title=None, **kwargs):
241253 self .guess_pie_columns (xlabel_sep = key_word_sep )
242254 plot = plt .bar (range (len (self .ys [0 ])), self .ys [0 ], ** kwargs )
243255 if self .xlabels :
244- plt .xticks (range (len (self .xlabels )), self .xlabels ,
245- rotation = 45 )
256+ plt .xticks (range (len (self .xlabels )), self .xlabels , rotation = 45 )
246257 plt .xlabel (self .xlabel )
247258 plt .ylabel (self .ys [0 ].name )
248259 return plot
@@ -251,7 +262,7 @@ def csv(self, filename=None, **format_params):
251262 """Generate results in comma-separated form. Write to ``filename`` if given.
252263 Any other parameters will be passed on to csv.writer."""
253264 if not self .pretty :
254- return None # no results
265+ return None # no results
255266 self .pretty .add_rows (self )
256267 if filename :
257268 encoding = format_params .get ('encoding' , 'utf-8' )
@@ -279,17 +290,37 @@ def interpret_rowcount(rowcount):
279290 result = '%d rows affected.' % rowcount
280291 return result
281292
293+
282294class FakeResultProxy (object ):
283295 """A fake class that pretends to behave like the ResultProxy from
284296 SqlAlchemy.
285297 """
298+
286299 def __init__ (self , cursor , headers ):
287300 self .fetchall = cursor .fetchall
288301 self .fetchmany = cursor .fetchmany
289302 self .rowcount = cursor .rowcount
290303 self .keys = lambda : headers
291304 self .returns_rows = True
292305
306+ # some dialects have autocommit
307+ # specific dialects break when commit is used:
308+ _COMMIT_BLACKLIST_DIALECTS = ('mssql' , 'clickhouse' )
309+
310+
311+ def _commit (conn , config ):
312+ """Issues a commit, if appropriate for current config and dialect"""
313+
314+ _should_commit = config .autocommit and all (
315+ dialect not in str (conn .dialect )
316+ for dialect in _COMMIT_BLACKLIST_DIALECTS )
317+
318+ if _should_commit :
319+ try :
320+ conn .session .execute ('commit' )
321+ except sqlalchemy .exc .OperationalError :
322+ pass # not all engines can commit
323+
293324
294325def run (conn , sql , config , user_namespace ):
295326 if sql .strip ():
@@ -302,21 +333,12 @@ def run(conn, sql, config, user_namespace):
302333 raise ImportError ('pgspecial not installed' )
303334 pgspecial = PGSpecial ()
304335 _ , cur , headers , _ = pgspecial .execute (
305- conn .session .connection .cursor (),
306- statement )[0 ]
336+ conn .session .connection .cursor (), statement )[0 ]
307337 result = FakeResultProxy (cur , headers )
308338 else :
309339 txt = sqlalchemy .sql .text (statement )
310340 result = conn .session .execute (txt , user_namespace )
311- try :
312- # some dialects have autocommit
313- # specific dialects break when commit is used:
314- dialects_blacklist = ('mssql' , 'clickhouse' )
315- if config .autocommit \
316- and all (dialect not in str (conn .dialect ) for dialect in dialects_blacklist ):
317- conn .session .execute ('commit' )
318- except sqlalchemy .exc .OperationalError :
319- pass # not all engines can commit
341+ _commit (conn = conn , config = config )
320342 if result and config .feedback :
321343 print (interpret_rowcount (result .rowcount ))
322344 resultset = ResultSet (result , statement , config )
@@ -330,11 +352,10 @@ def run(conn, sql, config, user_namespace):
330352
331353
332354class PrettyTable (prettytable .PrettyTable ):
333-
334355 def __init__ (self , * args , ** kwargs ):
335356 self .row_count = 0
336357 self .displaylimit = None
337- return super (PrettyTable , self ).__init__ (* args , ** kwargs )
358+ return super (PrettyTable , self ).__init__ (* args , ** kwargs )
338359
339360 def add_rows (self , data ):
340361 if self .row_count and (data .config .displaylimit == self .displaylimit ):
0 commit comments