From b13b25bf6fae59b1e10f0f0f8a6077e84d9730de Mon Sep 17 00:00:00 2001 From: myyc Date: Wed, 18 Feb 2015 11:41:37 +0100 Subject: [PATCH] dropping sasl dependency, hooray pure-sasl! --- pyhs2/cloudera/__init__.py | 1 - pyhs2/cloudera/thrift_sasl.py | 191 ---------------------------------- pyhs2/connections.py | 73 ++++--------- pyhs2/twitter/__init__.py | 1 + pyhs2/twitter/thrift_sasl.py | 133 +++++++++++++++++++++++ setup.py | 4 +- 6 files changed, 157 insertions(+), 246 deletions(-) delete mode 100644 pyhs2/cloudera/__init__.py delete mode 100644 pyhs2/cloudera/thrift_sasl.py create mode 100644 pyhs2/twitter/__init__.py create mode 100644 pyhs2/twitter/thrift_sasl.py diff --git a/pyhs2/cloudera/__init__.py b/pyhs2/cloudera/__init__.py deleted file mode 100644 index c69dbc1..0000000 --- a/pyhs2/cloudera/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__all__ = ['tsasclienttransport'] diff --git a/pyhs2/cloudera/thrift_sasl.py b/pyhs2/cloudera/thrift_sasl.py deleted file mode 100644 index 37e4143..0000000 --- a/pyhs2/cloudera/thrift_sasl.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python -# Licensed to Cloudera, Inc. under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -""" SASL transports for Thrift. 
""" - -from cStringIO import StringIO -from thrift.transport import TTransport -from thrift.transport.TTransport import * -from thrift.protocol import TBinaryProtocol -import sasl -import struct - -class TSaslClientTransport(TTransportBase, CReadableTransport): - START = 1 - OK = 2 - BAD = 3 - ERROR = 4 - COMPLETE = 5 - - def __init__(self, sasl_client_factory, mechanism, trans): - """ - @param sasl_client_factory: a callable that returns a new sasl.Client object - @param mechanism: the SASL mechanism (e.g. "GSSAPI") - @param trans: the underlying transport over which to communicate. - """ - self._trans = trans - self.sasl_client_factory = sasl_client_factory - self.sasl = None - self.mechanism = mechanism - self.__wbuf = StringIO() - self.__rbuf = StringIO() - self.opened = False - self.encode = None - - def isOpen(self): - return self._trans.isOpen() - - def open(self): - if not self._trans.isOpen(): - self._trans.open() - - if self.sasl is not None: - raise TTransportException( - type=TTransportException.NOT_OPEN, - message="Already open!") - self.sasl = self.sasl_client_factory - - ret, chosen_mech, initial_response = self.sasl.start(self.mechanism) - if not ret: - raise TTransportException(type=TTransportException.NOT_OPEN, - message=("Could not start SASL: %s" % self.sasl.getError())) - - # Send initial response - self._send_message(self.START, chosen_mech) - self._send_message(self.OK, initial_response) - - # SASL negotiation loop - while True: - status, payload = self._recv_sasl_message() - if status not in (self.OK, self.COMPLETE): - raise TTransportException(type=TTransportException.NOT_OPEN, - message=("Bad status: %d (%s)" % (status, payload))) - if status == self.COMPLETE: - break - ret, response = self.sasl.step(payload) - if not ret: - raise TTransportException(type=TTransportException.NOT_OPEN, - message=("Bad SASL result: %s" % (self.sasl.getError()))) - self._send_message(self.OK, response) - - def _send_message(self, status, body): - header = 
struct.pack(">BI", status, len(body)) - self._trans.write(header + body) - self._trans.flush() - - def _recv_sasl_message(self): - header = self._trans.readAll(5) - status, length = struct.unpack(">BI", header) - if length > 0: - payload = self._trans.readAll(length) - else: - payload = "" - return status, payload - - def write(self, data): - self.__wbuf.write(data) - - def flush(self): - buffer = self.__wbuf.getvalue() - # The first time we flush data, we send it to sasl.encode() - # If the length doesn't change, then we must be using a QOP - # of auth and we should no longer call sasl.encode(), otherwise - # we encode every time. - if self.encode == None: - success, encoded = self.sasl.encode(buffer) - if not success: - raise TTransportException(type=TTransportException.UNKNOWN, - message=self.sasl.getError()) - if (len(encoded)==len(buffer)): - self.encode = False - self._flushPlain(buffer) - else: - self.encode = True - self._trans.write(encoded) - elif self.encode: - self._flushEncoded(buffer) - else: - self._flushPlain(buffer) - - self._trans.flush() - self.__wbuf = StringIO() - - def _flushEncoded(self, buffer): - # sasl.ecnode() does the encoding and adds the length header, so nothing - # to do but call it and write the result. - success, encoded = self.sasl.encode(buffer) - if not success: - raise TTransportException(type=TTransportException.UNKNOWN, - message=self.sasl.getError()) - self._trans.write(encoded) - - def _flushPlain(self, buffer): - # When we have QOP of auth, sasl.encode() will pass the input to the output - # but won't put a length header, so we have to do that. - - # Note stolen from TFramedTransport: - # N.B.: Doing this string concatenation is WAY cheaper than making - # two separate calls to the underlying socket object. 
Socket writes in - # Python turn out to be REALLY expensive, but it seems to do a pretty - # good job of managing string buffer operations without excessive copies - self._trans.write(struct.pack(">I", len(buffer)) + buffer) - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self._read_frame() - return self.__rbuf.read(sz) - - def _read_frame(self): - header = self._trans.readAll(4) - (length,) = struct.unpack(">I", header) - if self.encode: - # If the frames are encoded (i.e. you're using a QOP of auth-int or - # auth-conf), then make sure to include the header in the bytes you send to - # sasl.decode() - encoded = header + self._trans.readAll(length) - success, decoded = self.sasl.decode(encoded) - if not success: - raise TTransportException(type=TTransportException.UNKNOWN, - message=self.sasl.getError()) - else: - # If the frames are not encoded, just pass it through - decoded = self._trans.readAll(length) - self.__rbuf = StringIO(decoded) - - def close(self): - self._trans.close() - self.sasl = None - - # Implement the CReadableTransport interface. - # Stolen shamelessly from TFramedTransport - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. 
- while len(prefix) < reqlen: - self._read_frame() - prefix += self.__rbuf.getvalue() - self.__rbuf = StringIO(prefix) - return self.__rbuf diff --git a/pyhs2/connections.py b/pyhs2/connections.py index 2291479..6117731 100644 --- a/pyhs2/connections.py +++ b/pyhs2/connections.py @@ -1,55 +1,37 @@ -import sys - -from thrift.protocol.TBinaryProtocol import TBinaryProtocol -from thrift.transport.TSocket import TSocket from thrift.transport.TTransport import TBufferedTransport -import sasl -from cloudera.thrift_sasl import TSaslClientTransport - +from TCLIService.ttypes import TOpenSessionReq, TCloseSessionReq +from thrift import Thrift +from thrift.transport import TSocket, TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocol from TCLIService import TCLIService - from cursor import Cursor -from TCLIService.ttypes import TCloseSessionReq,TOpenSessionReq +from twitter.thrift_sasl import TSaslClientTransport -class Connection(object): - DEFAULT_KRB_SERVICE = 'hive' - client = None - session = None - def __init__(self, host=None, port=10000, authMechanism=None, user=None, password=None, database=None, configuration=None, timeout=None): - authMechanisms = set(['NOSASL', 'PLAIN', 'KERBEROS', 'LDAP']) +class Connection(object): + def __init__(self, host=None, port=10000, authMechanism="PLAIN", username=None, password=None, database=None, + configuration=None, timeout=None): + authMechanisms = {"PLAIN", "NOSASL"} if authMechanism not in authMechanisms: - raise NotImplementedError('authMechanism is either not supported or not implemented') - #Must set a password for thrift, even if it doesn't need one - #Open issue with python-sasl - if authMechanism == 'PLAIN' and (password is None or len(password) == 0): - password = 'password' - socket = TSocket(host, port) - socket.setTimeout(timeout) - if authMechanism == 'NOSASL': - transport = TBufferedTransport(socket) - else: - sasl_mech = 'PLAIN' - saslc = sasl.Client() - saslc.setAttr("username", user) - 
saslc.setAttr("password", password) - if authMechanism == 'KERBEROS': - krb_host,krb_service = self._get_krb_settings(host, configuration) - sasl_mech = 'GSSAPI' - saslc.setAttr("host", krb_host) - saslc.setAttr("service", krb_service) + raise NotImplementedError("authMechanism '{}' is either not supported or not implemented".format(authMechanism)) - saslc.init() - transport = TSaslClientTransport(saslc, sasl_mech, socket) + socket = TSocket.TSocket(host, port) + socket.setTimeout(timeout) + if authMechanism == "NOSASL": + transport = TBufferedTransport(socket) + else: # authMechanism == "PLAIN": + password = "password" if (password is None or len(password) == 0) else password + transport = TSaslClientTransport(socket, host=host, service=None, mechanism=authMechanism, + username=username, password=password) self.client = TCLIService.Client(TBinaryProtocol(transport)) transport.open() - res = self.client.OpenSession(TOpenSessionReq(username=user, password=password, configuration=configuration)) + res = self.client.OpenSession(TOpenSessionReq(username=username, password=password, configuration=configuration)) self.session = res.sessionHandle if database is not None: with self.cursor() as cur: query = "USE {0}".format(database) - cur.execute(query) + cur.execute(query) def __enter__(self): return self @@ -57,22 +39,9 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_value, _traceback): self.close() - def _get_krb_settings(self, default_host, config): - host = default_host - service = self.DEFAULT_KRB_SERVICE - - if config is not None: - if 'krb_host' in config: - host = config['krb_host'] - - if 'krb_service' in config: - service = config['krb_service'] - - return host, service - def cursor(self): return Cursor(self.client, self.session) def close(self): req = TCloseSessionReq(sessionHandle=self.session) - self.client.CloseSession(req) \ No newline at end of file + self.client.CloseSession(req) diff --git a/pyhs2/twitter/__init__.py 
b/pyhs2/twitter/__init__.py
new file mode 100644
index 0000000..fef449c
--- /dev/null
+++ b/pyhs2/twitter/__init__.py
@@ -0,0 +1 @@
+__all__ = ["thrift_sasl"]
diff --git a/pyhs2/twitter/thrift_sasl.py b/pyhs2/twitter/thrift_sasl.py
new file mode 100644
index 0000000..fcd2750
--- /dev/null
+++ b/pyhs2/twitter/thrift_sasl.py
@@ -0,0 +1,133 @@
+"""
+    A SASL Thrift transport based upon the pure-sasl library, both implemented by
+    @tylhobbs and adapted for twitter.common. See:
+    https://issues.apache.org/jira/browse/THRIFT-1719
+    https://issues.apache.org/jira/secure/attachment/12548462/1719-python-sasl.txt
+"""
+
+from struct import pack, unpack
+from cStringIO import StringIO
+
+from puresasl.client import SASLClient
+from thrift.transport.TTransport import (
+    CReadableTransport,
+    TTransportBase,
+    TTransportException)
+
+
+class TSaslClientTransport(TTransportBase, CReadableTransport):
+    """
+    A SASL transport based on the pure-sasl library:
+        https://github.com/thobbs/pure-sasl
+    """
+
+    START = 1
+    OK = 2
+    BAD = 3
+    ERROR = 4
+    COMPLETE = 5
+
+    def __init__(self, transport, host, service, mechanism='GSSAPI',
+                 **sasl_kwargs):
+        """
+        transport: an underlying transport to use, typically just a TSocket
+        host: the name of the server, from a SASL perspective
+        service: the name of the server's service, from a SASL perspective
+        mechanism: the name of the preferred mechanism to use
+        All other kwargs will be passed to the puresasl.client.SASLClient
+        constructor.
+        """
+        self.transport = transport
+        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+        self.__wbuf = StringIO()
+        self.__rbuf = StringIO()
+        # NOTE(review): flush()/_read_frame() rely on sasl.wrap()/unwrap();
+        # confirm the pinned pure-sasl implements both for PLAIN.
+
+    def open(self):
+        """Open the underlying transport if needed and run the SASL handshake."""
+        if not self.transport.isOpen():
+            self.transport.open()
+
+        self.send_sasl_msg(self.START, self.sasl.mechanism)
+        self.send_sasl_msg(self.OK, self.sasl.process() or '')
+
+        while True:
+            status, challenge = self.recv_sasl_msg()
+            if status == self.OK:
+                self.send_sasl_msg(self.OK, self.sasl.process(challenge) or '')
+            elif status == self.COMPLETE:
+                # NOTE(review): sasl.complete was reportedly unset for PLAIN in some pure-sasl releases; confirm.
+                if not self.sasl.complete:
+                    raise TTransportException("The server erroneously indicated "
+                                              "that SASL negotiation was complete")
+                break
+            else:
+                raise TTransportException("Bad SASL negotiation status: %d (%s)" % (status, challenge))
+
+    def send_sasl_msg(self, status, body):
+        """Send one negotiation frame: status byte, 4-byte length, then body."""
+        if body is None:
+            body = ''
+        # Encode first so the header carries the byte length, not the
+        # character count (the two differ for non-ASCII text).
+        if not isinstance(body, bytes):
+            body = body.encode("utf-8")
+        self.transport.write(pack(">BI", status, len(body)) + body)
+        self.transport.flush()
+
+    def recv_sasl_msg(self):
+        """Read one negotiation frame and return it as (status, payload)."""
+        status, length = unpack(">BI", self.transport.readAll(5))
+        payload = self.transport.readAll(length) if length > 0 else ""
+        return status, payload
+
+    def write(self, data):
+        """Buffer data until the next flush()."""
+        self.__wbuf.write(data)
+
+    def flush(self):
+        """Pass the buffered data through sasl.wrap() and send it as one length-prefixed frame."""
+        data = self.__wbuf.getvalue()
+        encoded = self.sasl.wrap(data)
+        self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
+        self.transport.flush()
+        self.__wbuf = StringIO()
+
+    def read(self, sz):
+        """Return up to sz bytes, reading a new frame when the buffer is empty."""
+        ret = self.__rbuf.read(sz)
+        if len(ret) != 0:
+            return ret
+
+        self._read_frame()
+        return self.__rbuf.read(sz)
+
+    def _read_frame(self):
+        """Read one length-prefixed frame and store its sasl.unwrap() result as the read buffer."""
+        header = self.transport.readAll(4)
+        length, = unpack('!i', header)
+        encoded = self.transport.readAll(length)
+        self.__rbuf = StringIO(self.sasl.unwrap(encoded))
+
+    def close(self):
+        """Dispose of the SASL client, then close the underlying transport."""
+        self.sasl.dispose()
+        self.transport.close()
+
+    # based on TFramedTransport
+    @property
+    def cstringio_buf(self):
+        """Expose the current read buffer."""
+        return self.__rbuf
+
+    def cstringio_refill(self, prefix, reqlen):
+        """Read frames until at least reqlen bytes are buffered; return the new read buffer."""
+        # self.__rbuf will already be empty here because fastbinary doesn't
+        # ask for a refill until the previous buffer is empty. Therefore,
+        # we can start reading new frames immediately.
+        while len(prefix) < reqlen:
+            self._read_frame()
+            prefix += self.__rbuf.getvalue()
+        self.__rbuf = StringIO(prefix)
+        return self.__rbuf
diff --git a/setup.py b/setup.py
index a0dfc90..ea96a52 100644
--- a/setup.py
+++ b/setup.py
@@ -5,13 +5,13 @@
     version='0.6.0',
     author='Brad Ruderman',
     author_email='bradruderman@gmail.com',
-    packages=['pyhs2', 'pyhs2/cloudera', 'pyhs2/TCLIService'],
+    packages=['pyhs2', 'pyhs2/twitter', 'pyhs2/TCLIService'],
     url='https://github.com/BradRuderman/pyhs2',
     license='LICENSE.txt',
     description='Python Hive Server 2 Client Driver',
     long_description=open('README.md').read(),
     install_requires=[
-        "sasl",
+        "pure-sasl>=0.1.7",
         "thrift",
     ],
     test_suite='pyhs2.test',