From e2b32af057db63879733d6f52531ba3274b3d1c0 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:23:22 -0500 Subject: [PATCH 01/46] fix DeprecationWarning emitted from _cli_main --- sshtunnel.py | 67 +++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index a7db0c4..58a1b28 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1881,47 +1881,56 @@ def _parse_arguments(args=None): def _cli_main(args=None, **extras): - """ Pass input arguments to open_tunnel - - Mandatory: ssh_address, -R (remote bind address list) - - Optional: - -U (username) we may gather it from SSH_CONFIG_FILE or current username - -p (server_port), defaults to 22 - -P (password) - -L (local_bind_address), default to 0.0.0.0:22 - -k (ssh_host_key) - -K (private_key_file), may be gathered from SSH_CONFIG_FILE - -S (private_key_password) - -t (threaded), allow concurrent connections over tunnels - -v (verbose), up to 3 (-vvv) to raise loglevel from ERROR to DEBUG - -V (version) - -x (proxy), ProxyCommand's IP:PORT, may be gathered from config file - -c (ssh_config), ssh configuration file (defaults to SSH_CONFIG_FILE) - -z (compress) - -n (noagent), disable looking for keys from an Agent - -d (host_pkey_directories), look for keys on these folders + """Pass input arguments to open_tunnel + + Mandatory: ssh_address, -R (remote bind address list) + + Optional: + -U (username) we may gather it from SSH_CONFIG_FILE or current username + -p (server_port), defaults to 22 + -P (password) + -L (local_bind_address), default to 0.0.0.0:22 + -k (ssh_host_key) + -K (private_key_file), may be gathered from SSH_CONFIG_FILE + -S (private_key_password) + -t (threaded), allow concurrent connections over tunnels + -v (verbose), up to 3 (-vvv) to raise loglevel from ERROR to DEBUG + -V (version) + -x (proxy), ProxyCommand's IP:PORT, may be gathered from config file + -c (ssh_config), ssh configuration file (defaults to SSH_CONFIG_FILE) + -z (compress) + -n (noagent), disable looking for keys from an Agent + -d (host_pkey_directories), look for keys on these folders """ arguments = _parse_arguments(args) + # Remove all "None" input values _remove_none_values(arguments) + + for old_key in ['ssh_address', 'ssh_host']: + if old_key in arguments: + arguments['ssh_address_or_host'] = arguments.pop(old_key) + verbosity = min(arguments.pop('verbose'), 4) - levels = [logging.ERROR, - logging.WARNING, - logging.INFO, - logging.DEBUG, - TRACE_LEVEL] + levels = [ + logging.ERROR, + logging.WARNING, + logging.INFO, + logging.DEBUG, + TRACE_LEVEL, + ] arguments.setdefault('debug_level', levels[verbosity]) - # do this while supporting py27/py34 instead of merging dicts - for (extra, value) in extras.items(): + + # do this while supporting py27 instead of merging dicts + for extra, value in extras.items(): arguments.setdefault(extra, value) with open_tunnel(**arguments) as tunnel: if tunnel.is_alive: - input_(''' + input_(""" Press or to stop! - ''') + """) if __name__ == '__main__': # pragma: no cover From 5b360f2717414fb5e17a6061568dadfe6a18e903 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:25:46 -0500 Subject: [PATCH 02/46] descriptive noqa comments --- sshtunnel.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 58a1b28..781437e 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -29,8 +29,9 @@ if sys.version_info[0] < 3: # pragma: no cover import Queue as queue import SocketServer as socketserver - string_types = basestring, # noqa - input_ = raw_input # noqa + + string_types = basestring # noqa: F821 undefined name + input_ = raw_input # noqa: F821 undefined name else: # pragma: no cover import queue import socketserver @@ -1061,10 +1062,11 @@ def get_agent_keys(logger=None): return list(agent_keys) @staticmethod - def get_keys(logger=None, host_pkey_directories=None, allow_agent=False): + def get_keys( # noqa: C901 too complex + logger=None, host_pkey_directories=None, allow_agent=False + ): """ - Load public keys from any available SSH agent or local - .ssh directory. + Load public keys from any available SSH agent or local .ssh directory. Arguments: logger (Optional[logging.Logger]) @@ -1874,8 +1876,7 @@ def _parse_arguments(args=None): nargs='*', dest='host_pkey_directories', metavar='FOLDER', - help='List of directories where SSH pkeys (in the format `id_*`) ' - 'may be found' + help='List of directories where SSH pkeys (in the format `id_*`) may be found', # noqa: E501 line too long ) return vars(parser.parse_args(args)) From 300f783d98e235057be610336f7cf7503442dc75 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:26:27 -0500 Subject: [PATCH 03/46] sort functions and methods --- sshtunnel.py | 882 ++++++++++++++++++++-------------------- tests/test_forwarder.py | 192 ++++----- 2 files changed, 538 insertions(+), 536 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 781437e..e9d3c66 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -157,6 +157,39 @@ def check_addresses(address_list, is_remote=False): check_address(address) +def _add_handler(logger, handler=None, loglevel=None): + """ + Add a handler to an existing logging.Logger object + """ + handler.setLevel(loglevel or DEFAULT_LOGLEVEL) + if handler.level <= logging.DEBUG: + _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ + '%(lineno)04d@%(module)-10.9s| %(message)s' + handler.setFormatter(logging.Formatter(_fmt)) + else: + handler.setFormatter(logging.Formatter( + '%(asctime)s| %(levelname)-8s| %(message)s' + )) + logger.addHandler(handler) + + +def _check_paramiko_handlers(logger=None): + """ + Add a console handler for paramiko.transport's logger if not present + """ + paramiko_logger = logging.getLogger('paramiko.transport') + if not paramiko_logger.handlers: + if logger: + paramiko_logger.handlers = logger.handlers + else: + console_handler = logging.StreamHandler() + console_handler.setFormatter( + logging.Formatter('%(asctime)s | %(levelname)-8s| PARAMIKO: ' + '%(lineno)03d@%(module)-10s| %(message)s') + ) + paramiko_logger.addHandler(console_handler) + + def create_logger(logger=None, loglevel=None, capture_warnings=True, @@ -216,39 +249,6 @@ def create_logger(logger=None, return logger -def _add_handler(logger, handler=None, loglevel=None): - """ - Add a handler to an existing logging.Logger object - """ - handler.setLevel(loglevel or DEFAULT_LOGLEVEL) - if handler.level <= logging.DEBUG: - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' - handler.setFormatter(logging.Formatter(_fmt)) - else: - handler.setFormatter(logging.Formatter( - '%(asctime)s| %(levelname)-8s| %(message)s' - )) - logger.addHandler(handler) - - -def _check_paramiko_handlers(logger=None): - """ - Add a console handler for paramiko.transport's logger if not present - """ - paramiko_logger = logging.getLogger('paramiko.transport') - if not paramiko_logger.handlers: - if logger: - paramiko_logger.handlers = logger.handlers - else: - console_handler = logging.StreamHandler() - console_handler.setFormatter( - logging.Formatter('%(asctime)s | %(levelname)-8s| PARAMIKO: ' - '%(lineno)03d@%(module)-10s| %(message)s') - ) - paramiko_logger.addHandler(console_handler) - - def address_to_str(address): if isinstance(address, tuple): return '{0[0]}:{0[1]}'.format(address) @@ -748,137 +748,165 @@ class SSHTunnelForwarder(object): # This option affect only `Transport` thread daemon_transport = _DAEMON #: flag SSH transport thread in daemon mode - def local_is_up(self, target): - """ - Check if a tunnel is up (remote target's host is reachable on TCP - target's port) - - Arguments: - target (tuple): - tuple of type (``str``, ``int``) indicating the listen IP - address and port - Return: - boolean - - .. deprecated:: 0.1.0 - Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` - """ - try: - check_address(target) - except ValueError: - self.logger.warning('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.') - return False - - self.check_tunnels() - return self.tunnel_is_up.get(target, True) - - def check_tunnels(self): + @staticmethod + def _read_ssh_config(ssh_host, + ssh_config_file, + ssh_username=None, + ssh_pkey=None, + ssh_port=None, + ssh_proxy=None, + compression=None, + logger=None): """ - Check that if all tunnels are established and populates - :attr:`.tunnel_is_up` + Read ssh_config_file and tries to look for user (ssh_username), + identityfile (ssh_pkey), port (ssh_port) and proxycommand + (ssh_proxy) entries for ssh_host """ - skip_tunnel_checkup = self.skip_tunnel_checkup - try: - # force tunnel check at this point - self.skip_tunnel_checkup = False - for _srv in self._server_list: - self._check_tunnel(_srv) - finally: - self.skip_tunnel_checkup = skip_tunnel_checkup # roll it back + ssh_config = paramiko.SSHConfig() + if not ssh_config_file: # handle case where it's an empty string + ssh_config_file = None - def _check_tunnel(self, _srv): - """ Check if tunnel is already established """ - if self.skip_tunnel_checkup: - self.tunnel_is_up[_srv.local_address] = True - return - self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address)) - if isinstance(_srv.local_address, string_types): # UNIX stream - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - else: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.settimeout(TUNNEL_TIMEOUT) + # Try to read SSH_CONFIG_FILE try: - # Windows raises WinError 10049 if trying to connect to 0.0.0.0 - connect_to = ('127.0.0.1', _srv.local_port) \ - if _srv.local_host == '0.0.0.0' else _srv.local_address - s.connect(connect_to) - self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( - timeout=TUNNEL_TIMEOUT * 1.1 - ) - self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) + # open the ssh config file + with open(os.path.expanduser(ssh_config_file), 'r') as f: + ssh_config.parse(f) + # looks for information for the destination system + hostname_info = ssh_config.lookup(ssh_host) + # gather settings for user, port and identity file + # last resort: use the 'login name' of the user + ssh_username = ( + ssh_username or + hostname_info.get('user') ) - except socket.error: - self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) + ssh_pkey = ( + ssh_pkey or + hostname_info.get('identityfile', [None])[0] ) - self.tunnel_is_up[_srv.local_address] = False + ssh_host = hostname_info.get('hostname') + ssh_port = ssh_port or hostname_info.get('port') - except queue.Empty: - self.logger.debug( - 'Tunnel to {0} is UP'.format(_srv.remote_address) - ) - self.tunnel_is_up[_srv.local_address] = True + proxycommand = hostname_info.get('proxycommand') + ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if + proxycommand else None) + if compression is None: + compression = hostname_info.get('compression', '') + compression = True if compression.upper() == 'YES' else False + except IOError: + if logger: + logger.warning( + 'Could not read SSH configuration file: {0}' + .format(ssh_config_file) + ) + except (AttributeError, TypeError): # ssh_config_file is None + if logger: + logger.info('Skipping loading of ssh configuration file') finally: - s.close() + return (ssh_host, + ssh_username or getpass.getuser(), + ssh_pkey, + int(ssh_port) if ssh_port else 22, # fallback value + ssh_proxy, + compression) - def _make_ssh_forward_handler_class(self, remote_address_): + @staticmethod + def _consolidate_binds(local_binds, remote_binds): """ - Make SSH Handler class + Fill local_binds with defaults when no value/s were specified, + leaving paramiko to decide in which local port the tunnel will be open """ - class Handler(_ForwardHandler): - remote_address = remote_address_ - ssh_transport = self._transport - logger = self.logger - return Handler - - def _make_ssh_forward_server_class(self, remote_address_): - return _ThreadingForwardServer if self._threaded else _ForwardServer - - def _make_stream_ssh_forward_server_class(self, remote_address_): - return _ThreadingStreamForwardServer if self._threaded \ - else _StreamForwardServer + count = len(remote_binds) - len(local_binds) + if count < 0: + raise ValueError('Too many local bind addresses ' + '(local_bind_addresses > remote_bind_addresses)') + local_binds.extend([('0.0.0.0', 0) for x in range(count)]) + return local_binds - def _make_ssh_forward_server(self, remote_address, local_bind_address): + @staticmethod + def _consolidate_auth(ssh_password=None, + ssh_pkey=None, + ssh_pkey_password=None, + allow_agent=True, + host_pkey_directories=None, + logger=None): """ - Make SSH forward proxy Server class + Get sure authentication information is in place. + ``ssh_pkey`` may be of classes: + - ``str`` - in this case it represents a private key file; public + key will be obtained from it + - ``paramiko.Pkey`` - it will be transparently added to loaded keys + """ - _Handler = self._make_ssh_forward_handler_class(remote_address) - try: - forward_maker_class = self._make_stream_ssh_forward_server_class \ - if isinstance(local_bind_address, string_types) \ - else self._make_ssh_forward_server_class - _Server = forward_maker_class(remote_address) - ssh_forward_server = _Server( - local_bind_address, - _Handler, - logger=self.logger, - ) + ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( + logger=logger, + host_pkey_directories=host_pkey_directories, + allow_agent=allow_agent + ) - if ssh_forward_server: - ssh_forward_server.daemon_threads = self.daemon_forward_servers - self._server_list.append(ssh_forward_server) - self.tunnel_is_up[ssh_forward_server.server_address] = False - else: - self._raise( - BaseSSHTunnelForwarderError, - 'Problem setting up ssh {0} <> {1} forwarder. You can ' - 'suppress this exception by using the `mute_exceptions`' - 'argument'.format(address_to_str(local_bind_address), - address_to_str(remote_address)) - ) - except IOError: - self._raise( - BaseSSHTunnelForwarderError, - "Couldn't open tunnel {0} <> {1} might be in use or " - "destination not reachable".format( - address_to_str(local_bind_address), - address_to_str(remote_address) + if isinstance(ssh_pkey, string_types): + ssh_pkey_expanded = os.path.expanduser(ssh_pkey) + if os.path.exists(ssh_pkey_expanded): + ssh_pkey = SSHTunnelForwarder.read_private_key_file( + pkey_file=ssh_pkey_expanded, + pkey_password=ssh_pkey_password or ssh_password, + logger=logger ) - ) + elif logger: + logger.warning('Private key file not found: {0}' + .format(ssh_pkey)) + if isinstance(ssh_pkey, paramiko.pkey.PKey): + ssh_loaded_pkeys.insert(0, ssh_pkey) + + if not ssh_password and not ssh_loaded_pkeys: + raise ValueError('No password or public key available!') + return (ssh_password, ssh_loaded_pkeys) + + @staticmethod + def _get_binds(bind_address, bind_addresses, is_remote=False): + addr_kind = 'remote' if is_remote else 'local' + + if not bind_address and not bind_addresses: + if is_remote: + raise ValueError("No {0} bind addresses specified. Use " + "'{0}_bind_address' or '{0}_bind_addresses'" + " argument".format(addr_kind)) + else: + return [] + elif bind_address and bind_addresses: + raise ValueError("You can't use both '{0}_bind_address' and " + "'{0}_bind_addresses' arguments. Use one of " + "them.".format(addr_kind)) + if bind_address: + bind_addresses = [bind_address] + if not is_remote: + # Add random port if missing in local bind + for (i, local_bind) in enumerate(bind_addresses): + if isinstance(local_bind, tuple) and len(local_bind) == 1: + bind_addresses[i] = (local_bind[0], 0) + check_addresses(bind_addresses, is_remote) + return bind_addresses + + @staticmethod + def _process_deprecated(attrib, deprecated_attrib, kwargs): + """ + Processes optional deprecate arguments + """ + if deprecated_attrib not in _DEPRECATIONS: + raise ValueError('{0} not included in deprecations list' + .format(deprecated_attrib)) + if deprecated_attrib in kwargs: + warnings.warn("'{0}' is DEPRECATED use '{1}' instead" + .format(deprecated_attrib, + _DEPRECATIONS[deprecated_attrib]), + DeprecationWarning) + if attrib: + raise ValueError("You can't use both '{0}' and '{1}'. " + "Please only use one of them" + .format(deprecated_attrib, + _DEPRECATIONS[deprecated_attrib])) + else: + return kwargs.pop(deprecated_attrib) + return attrib def __init__( self, @@ -984,66 +1012,151 @@ def __init__( self.logger.debug('Concurrent connections allowed: {0}' .format(self._threaded)) - @staticmethod - def _read_ssh_config(ssh_host, - ssh_config_file, - ssh_username=None, - ssh_pkey=None, - ssh_port=None, - ssh_proxy=None, - compression=None, - logger=None): + def __del__(self): + if self.is_active or self.is_alive: + self.logger.warning( + "It looks like you didn't call the .stop() before " + "the SSHTunnelForwarder obj was collected by " + "the garbage collector! Running .stop(force=True)") + self.stop(force=True) + + def local_is_up(self, target): """ - Read ssh_config_file and tries to look for user (ssh_username), - identityfile (ssh_pkey), port (ssh_port) and proxycommand - (ssh_proxy) entries for ssh_host + Check if a tunnel is up (remote target's host is reachable on TCP + target's port) + + Arguments: + target (tuple): + tuple of type (``str``, ``int``) indicating the listen IP + address and port + Return: + boolean + + .. deprecated:: 0.1.0 + Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` """ - ssh_config = paramiko.SSHConfig() - if not ssh_config_file: # handle case where it's an empty string - ssh_config_file = None + try: + check_address(target) + except ValueError: + self.logger.warning('Target must be a tuple (IP, port), where IP ' + 'is a string (i.e. "192.168.0.1") and port is ' + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.') + return False - # Try to read SSH_CONFIG_FILE + self.check_tunnels() + return self.tunnel_is_up.get(target, True) + + def _check_tunnel(self, _srv): + """ Check if tunnel is already established """ + if self.skip_tunnel_checkup: + self.tunnel_is_up[_srv.local_address] = True + return + self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address)) + if isinstance(_srv.local_address, string_types): # UNIX stream + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + else: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(TUNNEL_TIMEOUT) try: - # open the ssh config file - with open(os.path.expanduser(ssh_config_file), 'r') as f: - ssh_config.parse(f) - # looks for information for the destination system - hostname_info = ssh_config.lookup(ssh_host) - # gather settings for user, port and identity file - # last resort: use the 'login name' of the user - ssh_username = ( - ssh_username or - hostname_info.get('user') + # Windows raises WinError 10049 if trying to connect to 0.0.0.0 + connect_to = ('127.0.0.1', _srv.local_port) \ + if _srv.local_host == '0.0.0.0' else _srv.local_address + s.connect(connect_to) + self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( + timeout=TUNNEL_TIMEOUT * 1.1 ) - ssh_pkey = ( - ssh_pkey or - hostname_info.get('identityfile', [None])[0] + self.logger.debug( + 'Tunnel to {0} is DOWN'.format(_srv.remote_address) ) - ssh_host = hostname_info.get('hostname') - ssh_port = ssh_port or hostname_info.get('port') + except socket.error: + self.logger.debug( + 'Tunnel to {0} is DOWN'.format(_srv.remote_address) + ) + self.tunnel_is_up[_srv.local_address] = False - proxycommand = hostname_info.get('proxycommand') - ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if - proxycommand else None) - if compression is None: - compression = hostname_info.get('compression', '') - compression = True if compression.upper() == 'YES' else False + except queue.Empty: + self.logger.debug( + 'Tunnel to {0} is UP'.format(_srv.remote_address) + ) + self.tunnel_is_up[_srv.local_address] = True + finally: + s.close() + + def check_tunnels(self): + """ + Check that if all tunnels are established and populates + :attr:`.tunnel_is_up` + """ + skip_tunnel_checkup = self.skip_tunnel_checkup + try: + # force tunnel check at this point + self.skip_tunnel_checkup = False + for _srv in self._server_list: + self._check_tunnel(_srv) + finally: + self.skip_tunnel_checkup = skip_tunnel_checkup # roll it back + + def _make_ssh_forward_handler_class(self, remote_address_): + """ + Make SSH Handler class + """ + class Handler(_ForwardHandler): + remote_address = remote_address_ + ssh_transport = self._transport + logger = self.logger + return Handler + + def _make_ssh_forward_server_class(self, remote_address_): + return _ThreadingForwardServer if self._threaded else _ForwardServer + + def _make_stream_ssh_forward_server_class(self, remote_address_): + return _ThreadingStreamForwardServer if self._threaded \ + else _StreamForwardServer + + def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): + if self._raise_fwd_exc: + raise exception(reason) + else: + self.logger.error(repr(exception(reason))) + + def _make_ssh_forward_server(self, remote_address, local_bind_address): + """ + Make SSH forward proxy Server class + """ + _Handler = self._make_ssh_forward_handler_class(remote_address) + try: + forward_maker_class = self._make_stream_ssh_forward_server_class \ + if isinstance(local_bind_address, string_types) \ + else self._make_ssh_forward_server_class + _Server = forward_maker_class(remote_address) + ssh_forward_server = _Server( + local_bind_address, + _Handler, + logger=self.logger, + ) + + if ssh_forward_server: + ssh_forward_server.daemon_threads = self.daemon_forward_servers + self._server_list.append(ssh_forward_server) + self.tunnel_is_up[ssh_forward_server.server_address] = False + else: + self._raise( + BaseSSHTunnelForwarderError, + 'Problem setting up ssh {0} <> {1} forwarder. You can ' + 'suppress this exception by using the `mute_exceptions`' + 'argument'.format(address_to_str(local_bind_address), + address_to_str(remote_address)) + ) except IOError: - if logger: - logger.warning( - 'Could not read SSH configuration file: {0}' - .format(ssh_config_file) + self._raise( + BaseSSHTunnelForwarderError, + "Couldn't open tunnel {0} <> {1} might be in use or " + "destination not reachable".format( + address_to_str(local_bind_address), + address_to_str(remote_address) ) - except (AttributeError, TypeError): # ssh_config_file is None - if logger: - logger.info('Skipping loading of ssh configuration file') - finally: - return (ssh_host, - ssh_username or getpass.getuser(), - ssh_pkey, - int(ssh_port) if ssh_port else 22, # fallback value - ssh_proxy, - compression) + ) @staticmethod def get_agent_keys(logger=None): @@ -1119,93 +1232,115 @@ def get_keys( # noqa: C901 too complex logger.info('{0} key(s) loaded'.format(len(keys))) return keys - @staticmethod - def _consolidate_binds(local_binds, remote_binds): - """ - Fill local_binds with defaults when no value/s were specified, - leaving paramiko to decide in which local port the tunnel will be open - """ - count = len(remote_binds) - len(local_binds) - if count < 0: - raise ValueError('Too many local bind addresses ' - '(local_bind_addresses > remote_bind_addresses)') - local_binds.extend([('0.0.0.0', 0) for x in range(count)]) - return local_binds - - @staticmethod - def _consolidate_auth(ssh_password=None, - ssh_pkey=None, - ssh_pkey_password=None, - allow_agent=True, - host_pkey_directories=None, - logger=None): - """ - Get sure authentication information is in place. - ``ssh_pkey`` may be of classes: - - ``str`` - in this case it represents a private key file; public - key will be obtained from it - - ``paramiko.Pkey`` - it will be transparently added to loaded keys + def _get_transport(self): + """ Return the SSH transport to the remote gateway """ + if self.ssh_proxy: + if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): + proxy_repr = repr(self.ssh_proxy.cmd[1]) + else: + proxy_repr = repr(self.ssh_proxy) + self.logger.debug('Connecting via proxy: {0}'.format(proxy_repr)) + _socket = self.ssh_proxy + else: + _socket = (self.ssh_host, self.ssh_port) + if isinstance(_socket, socket.socket): + _socket.settimeout(SSH_TIMEOUT) + _socket.connect((self.ssh_host, self.ssh_port)) + transport = paramiko.Transport(_socket) + sock = transport.sock + if isinstance(sock, socket.socket): + sock.settimeout(SSH_TIMEOUT) + transport.set_keepalive(self.set_keepalive) + transport.use_compression(compress=self.compression) + transport.daemon = self.daemon_transport + # try to solve https://github.com/paramiko/paramiko/issues/1181 + # transport.banner_timeout = 200 + if isinstance(sock, socket.socket): + sock_timeout = sock.gettimeout() + sock_info = repr((sock.family, sock.type, sock.proto)) + self.logger.debug('Transport socket info: {0}, timeout={1}' + .format(sock_info, sock_timeout)) + return transport - """ - ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( - logger=logger, - host_pkey_directories=host_pkey_directories, - allow_agent=allow_agent - ) + def _check_is_started(self): + if not self.is_active: # underlying transport not alive + msg = 'Server is not started. Please .start() first!' + raise BaseSSHTunnelForwarderError(msg) + if not self.is_alive: + msg = 'Tunnels are not started. Please .start() first!' + raise HandlerSSHTunnelForwarderError(msg) - if isinstance(ssh_pkey, string_types): - ssh_pkey_expanded = os.path.expanduser(ssh_pkey) - if os.path.exists(ssh_pkey_expanded): - ssh_pkey = SSHTunnelForwarder.read_private_key_file( - pkey_file=ssh_pkey_expanded, - pkey_password=ssh_pkey_password or ssh_password, - logger=logger - ) - elif logger: - logger.warning('Private key file not found: {0}' - .format(ssh_pkey)) - if isinstance(ssh_pkey, paramiko.pkey.PKey): - ssh_loaded_pkeys.insert(0, ssh_pkey) + def _stop_transport(self, force=False): + """ Close the underlying transport when nothing more is needed """ + try: + self._check_is_started() + except (BaseSSHTunnelForwarderError, + HandlerSSHTunnelForwarderError) as e: + self.logger.warning(e) + if force and self.is_active: + # don't wait connections + self.logger.info('Closing ssh transport') + self._transport.close() + self._transport.stop_thread() + for _srv in self._server_list: + status = 'up' if self.tunnel_is_up[_srv.local_address] else 'down' + self.logger.info('Shutting down tunnel: {0} <> {1} ({2})'.format( + address_to_str(_srv.local_address), + address_to_str(_srv.remote_address), + status + )) + _srv.shutdown() + _srv.server_close() + # clean up the UNIX domain socket if we're using one + if isinstance(_srv, _StreamForwardServer): + try: + os.unlink(_srv.local_address) + except Exception as e: + self.logger.error('Unable to unlink socket {0}: {1}' + .format(_srv.local_address, repr(e))) + self.is_alive = False + if self.is_active: + self.logger.info('Closing ssh transport') + self._transport.close() + self._transport.stop_thread() + self.logger.debug('Transport is closed') - if not ssh_password and not ssh_loaded_pkeys: - raise ValueError('No password or public key available!') - return (ssh_password, ssh_loaded_pkeys) + def _connect_to_gateway(self): + """ + Open connection to SSH gateway + - First try with all keys loaded from an SSH agent (if allowed) + - Then with those passed directly or read from ~/.ssh/config + - As last resort, try with a provided password + """ + for key in self.ssh_pkeys: + self.logger.debug('Trying to log in with key: {0}' + .format(hexlify(key.get_fingerprint()))) + try: + self._transport = self._get_transport() + self._transport.connect(hostkey=self.ssh_host_key, + username=self.ssh_username, + pkey=key) + if self._transport.is_alive: + return + except paramiko.AuthenticationException: + self.logger.debug('Authentication error') + self._stop_transport() - def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): - if self._raise_fwd_exc: - raise exception(reason) - else: - self.logger.error(repr(exception(reason))) + if self.ssh_password: # avoid conflict using both pass and pkey + self.logger.debug('Trying to log in with password: {0}' + .format('*' * len(self.ssh_password))) + try: + self._transport = self._get_transport() + self._transport.connect(hostkey=self.ssh_host_key, + username=self.ssh_username, + password=self.ssh_password) + if self._transport.is_alive: + return + except paramiko.AuthenticationException: + self.logger.debug('Authentication error') + self._stop_transport() - def _get_transport(self): - """ Return the SSH transport to the remote gateway """ - if self.ssh_proxy: - if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): - proxy_repr = repr(self.ssh_proxy.cmd[1]) - else: - proxy_repr = repr(self.ssh_proxy) - self.logger.debug('Connecting via proxy: {0}'.format(proxy_repr)) - _socket = self.ssh_proxy - else: - _socket = (self.ssh_host, self.ssh_port) - if isinstance(_socket, socket.socket): - _socket.settimeout(SSH_TIMEOUT) - _socket.connect((self.ssh_host, self.ssh_port)) - transport = paramiko.Transport(_socket) - sock = transport.sock - if isinstance(sock, socket.socket): - sock.settimeout(SSH_TIMEOUT) - transport.set_keepalive(self.set_keepalive) - transport.use_compression(compress=self.compression) - transport.daemon = self.daemon_transport - # try to solve https://github.com/paramiko/paramiko/issues/1181 - # transport.banner_timeout = 200 - if isinstance(sock, socket.socket): - sock_timeout = sock.gettimeout() - sock_info = repr((sock.family, sock.type, sock.proto)) - self.logger.debug('Transport socket info: {0}, timeout={1}' - .format(sock_info, sock_timeout)) - return transport + self.logger.error('Could not open connection to gateway') def _create_tunnels(self): """ @@ -1231,53 +1366,6 @@ def _create_tunnels(self): msg = 'Problem setting SSH Forwarder up: {0}'.format(e.value) self.logger.error(msg) - @staticmethod - def _get_binds(bind_address, bind_addresses, is_remote=False): - addr_kind = 'remote' if is_remote else 'local' - - if not bind_address and not bind_addresses: - if is_remote: - raise ValueError("No {0} bind addresses specified. Use " - "'{0}_bind_address' or '{0}_bind_addresses'" - " argument".format(addr_kind)) - else: - return [] - elif bind_address and bind_addresses: - raise ValueError("You can't use both '{0}_bind_address' and " - "'{0}_bind_addresses' arguments. Use one of " - "them.".format(addr_kind)) - if bind_address: - bind_addresses = [bind_address] - if not is_remote: - # Add random port if missing in local bind - for (i, local_bind) in enumerate(bind_addresses): - if isinstance(local_bind, tuple) and len(local_bind) == 1: - bind_addresses[i] = (local_bind[0], 0) - check_addresses(bind_addresses, is_remote) - return bind_addresses - - @staticmethod - def _process_deprecated(attrib, deprecated_attrib, kwargs): - """ - Processes optional deprecate arguments - """ - if deprecated_attrib not in _DEPRECATIONS: - raise ValueError('{0} not included in deprecations list' - .format(deprecated_attrib)) - if deprecated_attrib in kwargs: - warnings.warn("'{0}' is DEPRECATED use '{1}' instead" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib]), - DeprecationWarning) - if attrib: - raise ValueError("You can't use both '{0}' and '{1}'. " - "Please only use one of them" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib])) - else: - return kwargs.pop(deprecated_attrib) - return attrib - @staticmethod def read_private_key_file(pkey_file, pkey_password=None, @@ -1323,6 +1411,21 @@ def read_private_key_file(pkey_file, .format(pkey_file, pkey_class)) return ssh_pkey + def _serve_forever_wrapper(self, _srv, poll_interval=0.1): + """ + Wrapper for the server created for a SSH forward + """ + self.logger.info('Opening tunnel: {0} <> {1}'.format( + address_to_str(_srv.local_address), + address_to_str(_srv.remote_address)) + ) + _srv.serve_forever(poll_interval) # blocks until finished + + self.logger.info('Tunnel: {0} <> {1} released'.format( + address_to_str(_srv.local_address), + address_to_str(_srv.remote_address)) + ) + def start(self): """ Start the SSH tunnels """ if self.is_alive: @@ -1391,93 +1494,6 @@ def restart(self): self.stop() self.start() - def _connect_to_gateway(self): - """ - Open connection to SSH gateway - - First try with all keys loaded from an SSH agent (if allowed) - - Then with those passed directly or read from ~/.ssh/config - - As last resort, try with a provided password - """ - for key in self.ssh_pkeys: - self.logger.debug('Trying to log in with key: {0}' - .format(hexlify(key.get_fingerprint()))) - try: - self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - pkey=key) - if self._transport.is_alive: - return - except paramiko.AuthenticationException: - self.logger.debug('Authentication error') - self._stop_transport() - - if self.ssh_password: # avoid conflict using both pass and pkey - self.logger.debug('Trying to log in with password: {0}' - .format('*' * len(self.ssh_password))) - try: - self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - password=self.ssh_password) - if self._transport.is_alive: - return - except paramiko.AuthenticationException: - self.logger.debug('Authentication error') - self._stop_transport() - - self.logger.error('Could not open connection to gateway') - - def _serve_forever_wrapper(self, _srv, poll_interval=0.1): - """ - Wrapper for the server created for a SSH forward - """ - self.logger.info('Opening tunnel: {0} <> {1}'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) - ) - _srv.serve_forever(poll_interval) # blocks until finished - - self.logger.info('Tunnel: {0} <> {1} released'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) - ) - - def _stop_transport(self, force=False): - """ Close the underlying transport when nothing more is needed """ - try: - self._check_is_started() - except (BaseSSHTunnelForwarderError, - HandlerSSHTunnelForwarderError) as e: - self.logger.warning(e) - if force and self.is_active: - # don't wait connections - self.logger.info('Closing ssh transport') - self._transport.close() - self._transport.stop_thread() - for _srv in self._server_list: - status = 'up' if self.tunnel_is_up[_srv.local_address] else 'down' - self.logger.info('Shutting down tunnel: {0} <> {1} ({2})'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address), - status - )) - _srv.shutdown() - _srv.server_close() - # clean up the UNIX domain socket if we're using one - if isinstance(_srv, _StreamForwardServer): - try: - os.unlink(_srv.local_address) - except Exception as e: - self.logger.error('Unable to unlink socket {0}: {1}' - .format(_srv.local_address, repr(e))) - self.is_alive = False - if self.is_active: - self.logger.info('Closing ssh transport') - self._transport.close() - self._transport.stop_thread() - self.logger.debug('Transport is closed') - @property def local_bind_port(self): # BACKWARDS COMPATIBILITY @@ -1553,13 +1569,16 @@ def is_active(self): return True return False - def _check_is_started(self): - if not self.is_active: # underlying transport not alive - msg = 'Server is not started. Please .start() first!' - raise BaseSSHTunnelForwarderError(msg) - if not self.is_alive: - msg = 'Tunnels are not started. Please .start() first!' - raise HandlerSSHTunnelForwarderError(msg) + def __exit__(self, *args): + self.stop(force=True) + + def __enter__(self): + try: + self.start() + return self + except KeyboardInterrupt: + self.__exit__() + raise def __str__(self): credentials = { @@ -1605,25 +1624,6 @@ def __str__(self): def __repr__(self): return self.__str__() - def __enter__(self): - try: - self.start() - return self - except KeyboardInterrupt: - self.__exit__() - raise - - def __exit__(self, *args): - self.stop(force=True) - - def __del__(self): - if self.is_active or self.is_alive: - self.logger.warning( - "It looks like you didn't call the .stop() before " - "the SSHTunnelForwarder obj was collected by " - "the garbage collector! Running .stop(force=True)") - self.stop(force=True) - def open_tunnel(*args, **kwargs): """ diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 40662d0..db8ca01 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -49,6 +49,9 @@ def get_random_string(length=12): return ''.join([random.choice(asciis) for _ in range(length)]) +HERE = path.abspath(path.dirname(__file__)) + + def get_test_data_path(x): return path.join(HERE, x) @@ -90,7 +93,6 @@ def capture_stdout_stderr(): 'ecdsa-sha2-nistp256': ECDSA, } DAEMON_THREADS = False -HERE = path.abspath(path.dirname(__file__)) THREADS_TIMEOUT = 5.0 PKEY_FILE = 'testrsa.key' ENCRYPTED_PKEY_FILE = 'testrsa_encrypted.key' @@ -264,6 +266,74 @@ def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): thread.name)) thread.join(timeout) + def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): + self.log.debug('forward-server Start') + self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport + try: + schan = self.ts.accept(timeout=timeout) + info = "forward-server schan <> echo" + self.log.info(info + " accept()") + echo = socket.create_connection( + (self.eaddr, self.eport) + ) + while self.is_server_working: + rqst, _, _ = select.select([schan, echo], + [], + [], + timeout) + if schan in rqst: + data = schan.recv(1024) + self.log.debug('{0} -->: {1}'.format(info, repr(data))) + echo.send(data) + if len(data) == 0: + break + if echo in rqst: + data = echo.recv(1024) + self.log.debug('{0} <--: {1}'.format(info, repr(data))) + schan.send(data) + if len(data) == 0: + break + self.log.info('<<< forward-server received STOP signal') + except socket.error: + self.log.critical('{0} sending RST'.format(info)) + # except Exception as e: + # # we reach this point usually when schan is None (paramiko bug?) + # self.log.critical(repr(e)) + finally: + if schan: + self.log.debug('{0} closing connection...'.format(info)) + schan.close() + echo.close() + self.log.debug('{0} connection closed.'.format(info)) + + def _run_ssh_server(self): + self.log.info('ssh-server Start') + try: + self.socks, addr = self.ssockl.accept() + except socket.timeout: + self.log.error('ssh-server connection timed out!') + self.running_threads.remove('ssh-server') + return + self.ts = paramiko.Transport(self.socks) + host_key = paramiko.RSAKey.from_private_key_file( + get_test_data_path(PKEY_FILE) + ) + self.ts.add_server_key(host_key) + server = NullServer(allowed_keys=FINGERPRINTS.keys(), + log=self.log) + t = threading.Thread(target=self._do_forwarding, + name='forward-server') + t.daemon = DAEMON_THREADS + self.running_threads.append(t.name) + self.threads[t.name] = t + t.start() + self.ts.start_server(self.ssh_event, server) + self.wait_for_thread(t, + timeout=None, + who='ssh-server') + self.log.info('ssh-server shutting down') + self.running_threads.remove('ssh-server') + def start_echo_and_ssh_server(self): self.is_server_working = True self.start_echo_server() @@ -296,42 +366,6 @@ def _test_server(self, *args, **kwargs): yield server server._stop_transport() - def start_echo_server(self): - t = threading.Thread(target=self._run_echo_server, - name='echo-server') - t.daemon = DAEMON_THREADS - self.running_threads.append(t.name) - self.threads[t.name] = t - t.start() - - def _run_ssh_server(self): - self.log.info('ssh-server Start') - try: - self.socks, addr = self.ssockl.accept() - except socket.timeout: - self.log.error('ssh-server connection timed out!') - self.running_threads.remove('ssh-server') - return - self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file( - get_test_data_path(PKEY_FILE) - ) - self.ts.add_server_key(host_key) - server = NullServer(allowed_keys=FINGERPRINTS.keys(), - log=self.log) - t = threading.Thread(target=self._do_forwarding, - name='forward-server') - t.daemon = DAEMON_THREADS - self.running_threads.append(t.name) - self.threads[t.name] = t - t.start() - self.ts.start_server(self.ssh_event, server) - self.wait_for_thread(t, - timeout=None, - who='ssh-server') - self.log.info('ssh-server shutting down') - self.running_threads.remove('ssh-server') - def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.info('echo-server Started') self.ssh_event.wait(timeout) # wait for transport @@ -380,45 +414,13 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.info('echo-server shutting down') self.running_threads.remove('echo-server') - def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): - self.log.debug('forward-server Start') - self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport - try: - schan = self.ts.accept(timeout=timeout) - info = "forward-server schan <> echo" - self.log.info(info + " accept()") - echo = socket.create_connection( - (self.eaddr, self.eport) - ) - while self.is_server_working: - rqst, _, _ = select.select([schan, echo], - [], - [], - timeout) - if schan in rqst: - data = schan.recv(1024) - self.log.debug('{0} -->: {1}'.format(info, repr(data))) - echo.send(data) - if len(data) == 0: - break - if echo in rqst: - data = echo.recv(1024) - self.log.debug('{0} <--: {1}'.format(info, repr(data))) - schan.send(data) - if len(data) == 0: - break - self.log.info('<<< forward-server received STOP signal') - except socket.error: - self.log.critical('{0} sending RST'.format(info)) - # except Exception as e: - # # we reach this point usually when schan is None (paramiko bug?) - # self.log.critical(repr(e)) - finally: - if schan: - self.log.debug('{0} closing connection...'.format(info)) - schan.close() - echo.close() - self.log.debug('{0} connection closed.'.format(info)) + def start_echo_server(self): + t = threading.Thread(target=self._run_echo_server, + name='echo-server') + t.daemon = DAEMON_THREADS + self.running_threads.append(t.name) + self.threads[t.name] = t + t.start() def randomize_eport(self): return random.randint(49152, 65535) @@ -1194,6 +1196,25 @@ def test_get_keys(self): class AuxiliaryTest(unittest.TestCase): """ Set of tests that do not need the mock SSH server or logger """ + def _test_parser(self, parser): + self.assertEqual(parser['ssh_address'], '10.10.10.10') + self.assertEqual(parser['ssh_username'], getpass.getuser()) + self.assertEqual(parser['ssh_port'], 22) + self.assertEqual(parser['ssh_password'], SSH_PASSWORD) + self.assertListEqual(parser['remote_bind_addresses'], + [('10.0.0.1', 8080), ('10.0.0.2', 8080)]) + self.assertListEqual(parser['local_bind_addresses'], + [('', 8081), ('', 8082)]) + self.assertEqual(parser['ssh_host_key'], str(SSH_DSS)) + self.assertEqual(parser['ssh_private_key'], __file__) + self.assertEqual(parser['ssh_private_key_password'], SSH_PASSWORD) + self.assertTrue(parser['threaded']) + self.assertEqual(parser['verbose'], 3) + self.assertEqual(parser['ssh_proxy'], ('10.0.0.2', 22)) + self.assertEqual(parser['ssh_config_file'], 'ssh_config') + self.assertTrue(parser['compression']) + self.assertFalse(parser['allow_agent']) + def test_parse_arguments_short(self): """ Test CLI argument parsing with short parameter names """ args = ['10.10.10.10', # ssh_address @@ -1245,25 +1266,6 @@ def test_parse_arguments_long(self): ) self._test_parser(parser) - def _test_parser(self, parser): - self.assertEqual(parser['ssh_address'], '10.10.10.10') - self.assertEqual(parser['ssh_username'], getpass.getuser()) - self.assertEqual(parser['ssh_port'], 22) - self.assertEqual(parser['ssh_password'], SSH_PASSWORD) - self.assertListEqual(parser['remote_bind_addresses'], - [('10.0.0.1', 8080), ('10.0.0.2', 8080)]) - self.assertListEqual(parser['local_bind_addresses'], - [('', 8081), ('', 8082)]) - self.assertEqual(parser['ssh_host_key'], str(SSH_DSS)) - self.assertEqual(parser['ssh_private_key'], __file__) - self.assertEqual(parser['ssh_private_key_password'], SSH_PASSWORD) - self.assertTrue(parser['threaded']) - self.assertEqual(parser['verbose'], 3) - self.assertEqual(parser['ssh_proxy'], ('10.0.0.2', 22)) - self.assertEqual(parser['ssh_config_file'], 'ssh_config') - self.assertTrue(parser['compression']) - self.assertFalse(parser['allow_agent']) - def test_bindlist(self): """ Test that _bindlist enforces IP:PORT format for local and remote binds From ab1cd1669a3544bb31c9be75fd3d8c412118f9fd Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:27:12 -0500 Subject: [PATCH 04/46] consolidate HERE into get_test_data_path --- tests/test_forwarder.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index db8ca01..5082d1f 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -49,11 +49,8 @@ def get_random_string(length=12): return ''.join([random.choice(asciis) for _ in range(length)]) -HERE = path.abspath(path.dirname(__file__)) - - def get_test_data_path(x): - return path.join(HERE, x) + return path.join(path.abspath(path.dirname(__file__)), x) @contextmanager From 7eeec985a9489364c2943e6212b1c190723bfa03 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:29:16 -0500 Subject: [PATCH 05/46] fix W504 line break after binary operator --- sshtunnel.py | 24 ++++++++++++++---------- tests/test_forwarder.py | 6 ++++-- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index e9d3c66..7d9dfde 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -113,8 +113,10 @@ def check_address(address): elif isinstance(address, string_types): if os.name != 'posix': raise ValueError('Platform does not support UNIX domain sockets') - if not (os.path.exists(address) or - os.access(os.path.dirname(address), os.W_OK)): + if not ( + os.path.exists(address) + or os.access(os.path.dirname(address), os.W_OK) + ): raise ValueError('ADDRESS not a valid socket domain socket ({0})' .format(address)) else: @@ -340,8 +342,10 @@ def _redirect(self, chan): def handle(self): uid = generate_random_string(5) - self.info = '#{0} <-- {1}'.format(uid, self.client_address or - self.server.local_address) + self.info = '#{0} <-- {1}'.format( + uid, self.client_address + or self.server.local_address + ) src_address = self.request.getpeername() if not isinstance(src_address, tuple): src_address = ('dummy', 12345) @@ -776,12 +780,12 @@ def _read_ssh_config(ssh_host, # gather settings for user, port and identity file # last resort: use the 'login name' of the user ssh_username = ( - ssh_username or - hostname_info.get('user') + ssh_username + or hostname_info.get('user') ) ssh_pkey = ( - ssh_pkey or - hostname_info.get('identityfile', [None])[0] + ssh_pkey + or hostname_info.get('identityfile', [None])[0] ) ssh_host = hostname_info.get('hostname') ssh_port = ssh_port or hostname_info.get('port') @@ -1563,8 +1567,8 @@ def tunnel_bindings(self): def is_active(self): """ Return True if the underlying SSH transport is up """ if ( - '_transport' in self.__dict__ and - self._transport.is_active() + '_transport' in self.__dict__ + and self._transport.is_active() ): return True return False diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 5082d1f..2ca5b1c 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -160,8 +160,10 @@ def check_auth_password(self, username, password): def check_auth_publickey(self, username, key): try: expected = FINGERPRINTS[key.get_name()] - _ok = (key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected) + _ok = ( + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected + ) except KeyError: _ok = False self.log.debug('NullServer >> pkey authentication for {0} {1}OK' From b68b829ca0330bcc4027cb2012cc41b0d1c6ad9a Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:30:13 -0500 Subject: [PATCH 06/46] sort imports --- docs/conf.py | 7 ++++--- e2e_tests/run_docker_e2e_db_tests.py | 14 ++++++++------ e2e_tests/run_docker_e2e_hangs_tests.py | 2 +- setup.py | 3 ++- sshtunnel.py | 12 ++++++------ tests/test_forwarder.py | 19 ++++++++++--------- 6 files changed, 31 insertions(+), 26 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7be950a..fabb7de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,15 +13,16 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os - -import sshtunnel +import sys # Patch to disable warning on non-local image import sphinx.environment from docutils.utils import get_source_line +import sshtunnel + + def _warn_node(self, msg, node): if not msg.startswith('nonlocal image URI found:'): self._warnfunc(msg, '%s:%s' % get_source_line(node)) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index b9ea4df..fea8ee2 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -1,14 +1,16 @@ +import logging +import os import select -import traceback import sys -import os -import time -from sshtunnel import SSHTunnelForwarder -import sshtunnel -import logging import threading +import time +import traceback + import paramiko +import sshtunnel +from sshtunnel import SSHTunnelForwarder + sshtunnel.DEFAULT_LOGLEVEL = 1 logging.basicConfig( format='%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s', level=1) diff --git a/e2e_tests/run_docker_e2e_hangs_tests.py b/e2e_tests/run_docker_e2e_hangs_tests.py index 0ec7449..7abd491 100644 --- a/e2e_tests/run_docker_e2e_hangs_tests.py +++ b/e2e_tests/run_docker_e2e_hangs_tests.py @@ -1,7 +1,7 @@ import logging -import sshtunnel import os +import sshtunnel if __name__ == '__main__': path = os.path.join(os.path.dirname(__file__), 'run_docker_e2e_db_tests.py') diff --git a/setup.py b/setup.py index ccaaab8..fc6e4da 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,9 @@ """ import re -from os import path from codecs import open # To use a consistent encoding +from os import path + from setuptools import setup # Always prefer setuptools over distutils here = path.abspath(path.dirname(__file__)) diff --git a/sshtunnel.py b/sshtunnel.py index 7d9dfde..8f09223 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -11,18 +11,18 @@ """ +import argparse +import getpass +import logging import os import random +import socket import string import sys -import socket -import getpass -import logging -import argparse -import warnings import threading -from select import select +import warnings from binascii import hexlify +from select import select import paramiko diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 2ca5b1c..56b5c47 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1,24 +1,25 @@ from __future__ import with_statement +import argparse +import getpass +import logging import os -import sys import random import select +import shutil import socket -import getpass -import logging -import argparse -import warnings +import sys +import tempfile import threading -from os import path, linesep -from functools import partial +import warnings from contextlib import contextmanager +from functools import partial +from os import linesep, path import mock import paramiko + import sshtunnel -import shutil -import tempfile if sys.version_info[0] == 2: from cStringIO import StringIO From 7d5aa7f1509fde04f76a8768614cec2d3d1e517b Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:30:46 -0500 Subject: [PATCH 07/46] remove invalid `# noqa` directives --- sshtunnel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 8f09223..8b26ac1 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1212,7 +1212,6 @@ def get_keys( # noqa: C901 too complex 'dsa': paramiko.DSSKey, 'ecdsa': paramiko.ECDSAKey} if hasattr(paramiko, 'Ed25519Key'): - # NOQA: new in paramiko>=2.2: http://docs.paramiko.org/en/stable/api/keys.html#module-paramiko.ed25519key paramiko_key_types['ed25519'] = paramiko.Ed25519Key for directory in host_pkey_directories: for keytype in paramiko_key_types.keys(): @@ -1391,7 +1390,6 @@ def read_private_key_file(pkey_file, ssh_pkey = None key_types = (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey) if hasattr(paramiko, 'Ed25519Key'): - # NOQA: new in paramiko>=2.2: http://docs.paramiko.org/en/stable/api/keys.html#module-paramiko.ed25519key key_types += (paramiko.Ed25519Key, ) for pkey_class in (key_type,) if key_type else key_types: try: From 5b9a696ef1a86a93e1e7267578c36e07bdcba09a Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:31:22 -0500 Subject: [PATCH 08/46] remove unused conf.py imports --- docs/conf.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index fabb7de..64eb688 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,9 +13,6 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import os -import sys - # Patch to disable warning on non-local image import sphinx.environment from docutils.utils import get_source_line From 4f4b26ef8a0ca74bf5a0826f91b0f828006a1917 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:38:45 -0500 Subject: [PATCH 09/46] Prefer `TypeError` exception for invalid type --- sshtunnel.py | 2 +- tests/test_forwarder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 8b26ac1..eb60030 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -120,7 +120,7 @@ def check_address(address): raise ValueError('ADDRESS not a valid socket domain socket ({0})' .format(address)) else: - raise ValueError('ADDRESS is not a tuple, string, or character buffer ' + raise TypeError('ADDRESS is not a tuple, string, or character buffer ' '({0})'.format(type(address).__name__)) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 56b5c47..9937a82 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1400,5 +1400,5 @@ def test_check_address(self): self.assertIsNone(sshtunnel.check_addresses(address_list)) with self.assertRaises(ValueError): sshtunnel.check_address('this is not valid') - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): sshtunnel.check_address(-1) # that's not valid either From 824ca7318b83d3583ddd9a041651630bfa203786 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:46:52 -0500 Subject: [PATCH 10/46] use ast.literal_eval() instead of eval where possible --- e2e_tests/run_docker_e2e_db_tests.py | 5 +++-- setup.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index fea8ee2..a2ca6ab 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -5,6 +5,7 @@ import threading import time import traceback +from ast import literal_eval import paramiko @@ -27,7 +28,7 @@ PG_USERNAME = 'postgres' PG_PASSWORD = 'postgres' PG_QUERY = 'select version()' -PG_EXPECT = eval( +PG_EXPECT = literal_eval( """('PostgreSQL 13.0 (Debian 13.0-1.pgdg100+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 8.3.0-6) 8.3.0, 64-bit',)""") MYSQL_DATABASE_NAME = 'main' @@ -40,7 +41,7 @@ MONGO_USERNAME = 'mongo' MONGO_PASSWORD = 'mongo' MONGO_QUERY = lambda client, db: client.server_info() -MONGO_EXPECT = eval( +MONGO_EXPECT = literal_eval( """{'version': '3.6.23', 'gitVersion': 'd352e6a4764659e0d0350ce77279de3c1f243e5c', 'modules': [], 'allocator': 'tcmalloc', 'javascriptEngine': 'mozjs', 'sysInfo': 'deprecated', 'versionArray': [3, 6, 23, 0], 'openssl': {'running': 'OpenSSL 1.0.2g 1 Mar 2016', 'compiled': 'OpenSSL 1.0.2g 1 Mar 2016'}, 'buildEnvironment': {'distmod': 'ubuntu1604', 'distarch': 'x86_64', 'cc': '/opt/mongodbtoolchain/v2/bin/gcc: gcc (GCC) 5.4.0', 'ccflags': '-fno-omit-frame-pointer -fno-strict-aliasing -ggdb -pthread -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -Werror -O2 -Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations -Wno-unused-but-set-variable -Wno-missing-braces -fstack-protector-strong -fno-builtin-memcmp', 'cxx': '/opt/mongodbtoolchain/v2/bin/g++: g++ (GCC) 5.4.0', 'cxxflags': '-Woverloaded-virtual -Wno-maybe-uninitialized -std=c++14', 'linkflags': '-pthread -Wl,-z,now -rdynamic -Wl,--fatal-warnings -fstack-protector-strong -fuse-ld=gold -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack -Wl,-z,relro', 'target_arch': 'x86_64', 'target_os': 'linux'}, 'bits': 64, 'debug': False, 'maxBsonObjectSize': 16777216, 'storageEngines': ['devnull', 'ephemeralForTest', 'mmapv1', 'wiredTiger'], 'ok': 1.0}""") diff --git a/setup.py b/setup.py index fc6e4da..9a05708 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ """ import re +from ast import literal_eval from codecs import open # To use a consistent encoding from os import path @@ -27,7 +28,7 @@ with open(path.join(here, name + '.py'), encoding='utf-8') as f: data = f.read() - version = eval(re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1)) + version = literal_eval(re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1)) setup( From e515f4474f85c57dababe020ed6bbc9b31056530 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 21:47:03 -0500 Subject: [PATCH 11/46] use secrets over random when possible --- sshtunnel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index eb60030..304c36c 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -15,7 +15,6 @@ import getpass import logging import os -import random import socket import string import sys @@ -24,6 +23,11 @@ from binascii import hexlify from select import select +try: + import secrets as random +except ImportError: + import random + import paramiko if sys.version_info[0] < 3: # pragma: no cover From ba72a95de5e5d12112e12c41600afa6573603bb6 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:04:41 -0500 Subject: [PATCH 12/46] when possible Do not catch blind exceptions --- sshtunnel.py | 8 ++++---- tests/test_forwarder.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 304c36c..2cb8f9d 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -360,9 +360,9 @@ def handle(self): src_addr=src_address, timeout=TUNNEL_TIMEOUT ) - except Exception as e: # pragma: no cover - msg_tupe = 'ssh ' if isinstance(e, paramiko.SSHException) else '' - exc_msg = 'open new channel {0}error: {1}'.format(msg_tupe, e) + except (paramiko.SSHException, EnvironmentError) as e: # pragma: no cover + type_msg = 'ssh ' if isinstance(e, paramiko.SSHException) else '' + exc_msg = 'open new channel {0}error: {1}'.format(type_msg, e) log_msg = '{0} {1}'.format(self.info, exc_msg) self.logger.log(TRACE_LEVEL, log_msg) raise HandlerSSHTunnelForwarderError(exc_msg) @@ -1302,7 +1302,7 @@ def _stop_transport(self, force=False): if isinstance(_srv, _StreamForwardServer): try: os.unlink(_srv.local_address) - except Exception as e: + except OSError as e: self.logger.error('Unable to unlink socket {0}: {1}' .format(_srv.local_address, repr(e))) self.is_alive = False diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 9937a82..2b2ff04 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -401,7 +401,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): s.close() socks.remove(s) self.log.info('<<< echo-server received STOP signal') - except Exception as e: + except AttributeError as e: self.log.info('echo-server got Exception: {0}'.format(repr(e))) finally: self.is_server_working = False From f0dfebd2e0251093c92a8fee9e1c92f5e1f121a7 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:09:41 -0500 Subject: [PATCH 13/46] don't use boolean positional values in function calls --- sshtunnel.py | 6 +++--- tests/test_forwarder.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 2cb8f9d..05015ba 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -249,7 +249,7 @@ def create_logger(logger=None, _check_paramiko_handlers(logger=logger) if capture_warnings and sys.version_info >= (2, 7): - logging.captureWarnings(True) + logging.captureWarnings(capture=True) pywarnings = logging.getLogger('py.warnings') pywarnings.handlers.extend(logger.handlers) return logger @@ -406,9 +406,9 @@ def handle_error(self, request, client_address): 'to remote {1} side of the tunnel: {2}' .format(local_side, remote_side, exc)) try: - self.tunnel_ok.put(False, block=False, timeout=0.1) + self.tunnel_ok.put(item=False, block=False, timeout=0.1) except queue.Full: - # wait untill tunnel_ok.get is called + # wait until tunnel_ok.get is called pass except exc: self.logger.error('unexpected internal error: {0}'.format(exc)) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 2b2ff04..737dcb1 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1141,13 +1141,13 @@ def test_make_ssh_forward_server_sets_daemon_true(self): """ Test `make_ssh_forward_server` respects `daemon_forward_servers=True` """ - self.check_make_ssh_forward_server_sets_daemon(True) + self.check_make_ssh_forward_server_sets_daemon(case=True) def test_make_ssh_forward_server_sets_daemon_false(self): """ Test `make_ssh_forward_server` respects `daemon_forward_servers=False` """ - self.check_make_ssh_forward_server_sets_daemon(False) + self.check_make_ssh_forward_server_sets_daemon(case=False) def test_get_keys(self): """ Test loading keys from the paramiko Agent """ From 952132cf9e158b62fe09d4f4c913501356713521 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:12:46 -0500 Subject: [PATCH 14/46] don't use string literals in exceptions --- sshtunnel.py | 80 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 05015ba..8422b26 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -116,16 +116,23 @@ def check_address(address): check_port(address[1]) elif isinstance(address, string_types): if os.name != 'posix': - raise ValueError('Platform does not support UNIX domain sockets') + msg = 'Platform does not support UNIX domain sockets' + raise ValueError(msg) if not ( os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK) ): - raise ValueError('ADDRESS not a valid socket domain socket ({0})' - .format(address)) + msg = ( + 'ADDRESS not a valid socket domain socket ({0})' + .format(address) + ) + raise ValueError(msg) else: - raise TypeError('ADDRESS is not a tuple, string, or character buffer ' - '({0})'.format(type(address).__name__)) + msg = ( + 'ADDRESS is not a tuple, string, or character buffer ' + '({0})'.format(type(address).__name__) + ) + raise TypeError(msg) def check_addresses(address_list, is_remote=False): @@ -156,8 +163,11 @@ def check_addresses(address_list, is_remote=False): """ assert all(isinstance(x, (tuple, string_types)) for x in address_list) if (is_remote and any(isinstance(x, string_types) for x in address_list)): - raise AssertionError('UNIX domain sockets not allowed for remote' - 'addresses') + msg = ( + 'UNIX domain sockets not allowed for remote' + 'addresses' + ) + raise AssertionError(msg) for address in address_list: check_address(address) @@ -825,8 +835,11 @@ def _consolidate_binds(local_binds, remote_binds): """ count = len(remote_binds) - len(local_binds) if count < 0: - raise ValueError('Too many local bind addresses ' - '(local_bind_addresses > remote_bind_addresses)') + msg = ( + 'Too many local bind addresses ' + '(local_bind_addresses > remote_bind_addresses)' + ) + raise ValueError(msg) local_binds.extend([('0.0.0.0', 0) for x in range(count)]) return local_binds @@ -866,7 +879,8 @@ def _consolidate_auth(ssh_password=None, ssh_loaded_pkeys.insert(0, ssh_pkey) if not ssh_password and not ssh_loaded_pkeys: - raise ValueError('No password or public key available!') + msg = 'No password or public key available!' + raise ValueError(msg) return (ssh_password, ssh_loaded_pkeys) @staticmethod @@ -875,15 +889,21 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): if not bind_address and not bind_addresses: if is_remote: - raise ValueError("No {0} bind addresses specified. Use " + msg = ( + "No {0} bind addresses specified. Use " "'{0}_bind_address' or '{0}_bind_addresses'" - " argument".format(addr_kind)) + " argument".format(addr_kind) + ) + raise ValueError(msg) else: return [] elif bind_address and bind_addresses: - raise ValueError("You can't use both '{0}_bind_address' and " + msg = ( + "You can't use both '{0}_bind_address' and " "'{0}_bind_addresses' arguments. Use one of " - "them.".format(addr_kind)) + "them.".format(addr_kind) + ) + raise ValueError(msg) if bind_address: bind_addresses = [bind_address] if not is_remote: @@ -900,18 +920,24 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): Processes optional deprecate arguments """ if deprecated_attrib not in _DEPRECATIONS: - raise ValueError('{0} not included in deprecations list' - .format(deprecated_attrib)) + msg = ( + '{0} not included in deprecations list' + .format(deprecated_attrib) + ) + raise ValueError(msg) if deprecated_attrib in kwargs: warnings.warn("'{0}' is DEPRECATED use '{1}' instead" .format(deprecated_attrib, _DEPRECATIONS[deprecated_attrib]), DeprecationWarning) if attrib: - raise ValueError("You can't use both '{0}' and '{1}'. " + msg = ( + "You can't use both '{0}' and '{1}'. " "Please only use one of them" .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib])) + _DEPRECATIONS[deprecated_attrib]) + ) + raise ValueError(msg) else: return kwargs.pop(deprecated_attrib) return attrib @@ -972,7 +998,8 @@ def __init__( ssh_port = kwargs.pop('ssh_port', None) if kwargs: - raise ValueError('Unknown arguments: {0}'.format(kwargs)) + msg = 'Unknown arguments: {0}'.format(kwargs) + raise ValueError(msg) # remote binds self._remote_binds = self._get_binds(remote_bind_address, @@ -1505,8 +1532,9 @@ def local_bind_port(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: + msg = 'Use .local_bind_ports property for more than one tunnel' raise BaseSSHTunnelForwarderError( - 'Use .local_bind_ports property for more than one tunnel' + msg ) return self.local_bind_ports[0] @@ -1515,8 +1543,9 @@ def local_bind_host(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: + msg = 'Use .local_bind_hosts property for more than one tunnel' raise BaseSSHTunnelForwarderError( - 'Use .local_bind_hosts property for more than one tunnel' + msg ) return self.local_bind_hosts[0] @@ -1525,8 +1554,9 @@ def local_bind_address(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: + msg = 'Use .local_bind_addresses property for more than one tunnel' raise BaseSSHTunnelForwarderError( - 'Use .local_bind_addresses property for more than one tunnel' + msg ) return self.local_bind_addresses[0] @@ -1724,11 +1754,13 @@ def _bindlist(input_str): _port = '22' # default port if not given return _ip, int(_port) except ValueError: + msg = 'Address tuple must be of type IP_ADDRESS:PORT' raise argparse.ArgumentTypeError( - 'Address tuple must be of type IP_ADDRESS:PORT' + msg ) except AssertionError: - raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!") + msg = "Both IP:PORT can't be missing!" + raise argparse.ArgumentTypeError(msg) def _parse_arguments(args=None): From afa868a0c79333fd005956fb93c86d91761cf488 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:18:17 -0500 Subject: [PATCH 15/46] remove unnecessary pragma: no cover --- sshtunnel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 8422b26..cc2c7cb 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -30,13 +30,13 @@ import paramiko -if sys.version_info[0] < 3: # pragma: no cover +if sys.version_info[0] < 3: import Queue as queue import SocketServer as socketserver string_types = basestring # noqa: F821 undefined name input_ = raw_input # noqa: F821 undefined name -else: # pragma: no cover +else: import queue import socketserver string_types = str @@ -370,7 +370,7 @@ def handle(self): src_addr=src_address, timeout=TUNNEL_TIMEOUT ) - except (paramiko.SSHException, EnvironmentError) as e: # pragma: no cover + except (paramiko.SSHException, EnvironmentError) as e: type_msg = 'ssh ' if isinstance(e, paramiko.SSHException) else '' exc_msg = 'open new channel {0}error: {1}'.format(type_msg, e) log_msg = '{0} {1}'.format(self.info, exc_msg) From a506cdafc1fca7102102accbb4cf1c7f34a8f7d8 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:18:57 -0500 Subject: [PATCH 16/46] continuation line indentation and line length --- setup.py | 4 +++- sshtunnel.py | 31 +++++++++++++++---------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 9a05708..45fe2ae 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,9 @@ with open(path.join(here, name + '.py'), encoding='utf-8') as f: data = f.read() - version = literal_eval(re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1)) + version = literal_eval( + re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1) + ) setup( diff --git a/sshtunnel.py b/sshtunnel.py index cc2c7cb..8d5813d 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -124,13 +124,13 @@ def check_address(address): ): msg = ( 'ADDRESS not a valid socket domain socket ({0})' - .format(address) + .format(address) ) raise ValueError(msg) else: msg = ( - 'ADDRESS is not a tuple, string, or character buffer ' - '({0})'.format(type(address).__name__) + 'ADDRESS is not a tuple, string, or character buffer ({0})' + .format(type(address).__name__) ) raise TypeError(msg) @@ -163,10 +163,7 @@ def check_addresses(address_list, is_remote=False): """ assert all(isinstance(x, (tuple, string_types)) for x in address_list) if (is_remote and any(isinstance(x, string_types) for x in address_list)): - msg = ( - 'UNIX domain sockets not allowed for remote' - 'addresses' - ) + msg = 'UNIX domain sockets not allowed for remote addresses' raise AssertionError(msg) for address in address_list: @@ -837,7 +834,7 @@ def _consolidate_binds(local_binds, remote_binds): if count < 0: msg = ( 'Too many local bind addresses ' - '(local_bind_addresses > remote_bind_addresses)' + '(local_bind_addresses > remote_bind_addresses)' ) raise ValueError(msg) local_binds.extend([('0.0.0.0', 0) for x in range(count)]) @@ -891,8 +888,8 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): if is_remote: msg = ( "No {0} bind addresses specified. Use " - "'{0}_bind_address' or '{0}_bind_addresses'" - " argument".format(addr_kind) + "'{0}_bind_address' or '{0}_bind_addresses'" + " argument".format(addr_kind) ) raise ValueError(msg) else: @@ -900,8 +897,8 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): elif bind_address and bind_addresses: msg = ( "You can't use both '{0}_bind_address' and " - "'{0}_bind_addresses' arguments. Use one of " - "them.".format(addr_kind) + "'{0}_bind_addresses' arguments. Use one of " + "them.".format(addr_kind) ) raise ValueError(msg) if bind_address: @@ -922,7 +919,7 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): if deprecated_attrib not in _DEPRECATIONS: msg = ( '{0} not included in deprecations list' - .format(deprecated_attrib) + .format(deprecated_attrib) ) raise ValueError(msg) if deprecated_attrib in kwargs: @@ -933,9 +930,11 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): if attrib: msg = ( "You can't use both '{0}' and '{1}'. " - "Please only use one of them" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib]) + "Please only use one of them" + .format( + deprecated_attrib, + _DEPRECATIONS[deprecated_attrib] + ) ) raise ValueError(msg) else: From f11475ecfd833a59395d19145a9a59669785e521 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 22:20:50 -0500 Subject: [PATCH 17/46] don't implicitly concatenate string literals accross lines --- sshtunnel.py | 6 ++++-- tests/test_forwarder.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 8d5813d..11b2779 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -176,8 +176,10 @@ def _add_handler(logger, handler=None, loglevel=None): """ handler.setLevel(loglevel or DEFAULT_LOGLEVEL) if handler.level <= logging.DEBUG: - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' + _fmt = ( + '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' + '%(lineno)04d@%(module)-10.9s| %(message)s' + ) handler.setFormatter(logging.Formatter(_fmt)) else: handler.setFormatter(logging.Formatter( diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 737dcb1..c6da5c0 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -213,8 +213,10 @@ def setUpClass(cls): cls.log.addHandler(cls._sshtunnel_log_handler) cls.sshtunnel_log_messages = cls._sshtunnel_log_handler.messages # set verbose format for logging - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' + _fmt = ( + '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' + '%(lineno)04d@%(module)-10.9s| %(message)s' + ) for handler in cls.log.handlers: handler.setFormatter(logging.Formatter(_fmt)) From e8bd1bd920d5825e1628bc8522218af5e4df2f26 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:04:29 -0500 Subject: [PATCH 18/46] logger messages should use lazy concatenation --- sshtunnel.py | 188 +++++++++++++++++++++++++--------------- tests/test_forwarder.py | 89 ++++++++++--------- 2 files changed, 169 insertions(+), 108 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 11b2779..6addc53 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -324,24 +324,23 @@ def _redirect(self, chan): if not data: self.logger.log( TRACE_LEVEL, - '>>> OUT {0} recv empty data >>>'.format(self.info) + '>>> OUT %s recv empty data >>>', self.info ) break if self.logger.isEnabledFor(TRACE_LEVEL): self.logger.log( TRACE_LEVEL, - '>>> OUT {0} send to {1}: {2} >>>'.format( - self.info, - self.remote_address, - hexlify(data) - ) + '>>> OUT %s send to %s: %s >>>', + self.info, + self.remote_address, + hexlify(data) ) chan.sendall(data) if chan in rqst: # else if not chan.recv_ready(): self.logger.log( TRACE_LEVEL, - '<<< IN {0} recv is not ready <<<'.format(self.info) + '<<< IN %s recv is not ready <<<', self.info ) break data = chan.recv(16384) @@ -349,7 +348,7 @@ def _redirect(self, chan): hex_data = hexlify(data) self.logger.log( TRACE_LEVEL, - '<<< IN {0} recv: {1} <<<'.format(self.info, hex_data) + '<<< IN %s recv: %s <<<', self.info, hex_data ) self.request.sendall(data) @@ -372,11 +371,14 @@ def handle(self): except (paramiko.SSHException, EnvironmentError) as e: type_msg = 'ssh ' if isinstance(e, paramiko.SSHException) else '' exc_msg = 'open new channel {0}error: {1}'.format(type_msg, e) - log_msg = '{0} {1}'.format(self.info, exc_msg) - self.logger.log(TRACE_LEVEL, log_msg) + self.logger.log(TRACE_LEVEL, '%s %s', self.info, exc_msg) raise HandlerSSHTunnelForwarderError(exc_msg) - self.logger.log(TRACE_LEVEL, '{0} connected'.format(self.info)) + self.logger.log( + TRACE_LEVEL, + '%s connected', + self.info + ) try: self._redirect(chan) except socket.error: @@ -384,15 +386,22 @@ def handle(self): # exception. It was seen that a 3way FIN is processed later on, so # no need to make an ordered close of the connection here or raise # the exception beyond this point... - self.logger.log(TRACE_LEVEL, '{0} sending RST'.format(self.info)) + self.logger.log(TRACE_LEVEL, '%s sending RST', self.info) except Exception as e: - self.logger.log(TRACE_LEVEL, - '{0} error: {1}'.format(self.info, repr(e))) + self.logger.log( + TRACE_LEVEL, + '%s error: %s', + self.info, + repr(e) + ) finally: chan.close() self.request.close() - self.logger.log(TRACE_LEVEL, - '{0} connection closed.'.format(self.info)) + self.logger.log( + TRACE_LEVEL, + '%s connection closed.', + self.info + ) class _ForwardServer(socketserver.TCPServer): # Not Threading @@ -411,16 +420,22 @@ def handle_error(self, request, client_address): (exc_class, exc, tb) = sys.exc_info() local_side = request.getsockname() remote_side = self.remote_address - self.logger.error('Could not establish connection from local {0} ' - 'to remote {1} side of the tunnel: {2}' - .format(local_side, remote_side, exc)) + self.logger.error( + '%s %s %s %s %s: %s', + 'Could not establish connection from local', + local_side, + 'to remote', + remote_side, + 'side of the tunnel', + exc + ) try: self.tunnel_ok.put(item=False, block=False, timeout=0.1) except queue.Full: # wait until tunnel_ok.get is called pass except exc: - self.logger.error('unexpected internal error: {0}'.format(exc)) + self.logger.error('unexpected internal error: %s', exc) @property def local_address(self): @@ -812,8 +827,8 @@ def _read_ssh_config(ssh_host, except IOError: if logger: logger.warning( - 'Could not read SSH configuration file: {0}' - .format(ssh_config_file) + 'Could not read SSH configuration file: %s', + ssh_config_file ) except (AttributeError, TypeError): # ssh_config_file is None if logger: @@ -872,8 +887,10 @@ def _consolidate_auth(ssh_password=None, logger=logger ) elif logger: - logger.warning('Private key file not found: {0}' - .format(ssh_pkey)) + logger.warning( + 'Private key file not found: %s', + ssh_pkey + ) if isinstance(ssh_pkey, paramiko.pkey.PKey): ssh_loaded_pkeys.insert(0, ssh_pkey) @@ -1040,13 +1057,17 @@ def __init__( check_host(self.ssh_host) check_port(self.ssh_port) - self.logger.info("Connecting to gateway: {0}:{1} as user '{2}'" - .format(self.ssh_host, - self.ssh_port, - self.ssh_username)) + self.logger.info( + "Connecting to gateway: %s:%s as user '%s'", + self.ssh_host, + self.ssh_port, + self.ssh_username + ) - self.logger.debug('Concurrent connections allowed: {0}' - .format(self._threaded)) + self.logger.debug( + 'Concurrent connections allowed: %s', + self._threaded + ) def __del__(self): if self.is_active or self.is_alive: @@ -1088,7 +1109,7 @@ def _check_tunnel(self, _srv): if self.skip_tunnel_checkup: self.tunnel_is_up[_srv.local_address] = True return - self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address)) + self.logger.info('Checking tunnel to: %s', _srv.remote_address) if isinstance(_srv.local_address, string_types): # UNIX stream s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: @@ -1103,17 +1124,17 @@ def _check_tunnel(self, _srv): timeout=TUNNEL_TIMEOUT * 1.1 ) self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) + 'Tunnel to %s is DOWN', _srv.remote_address ) except socket.error: self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) + 'Tunnel to %s is DOWN', _srv.remote_address ) self.tunnel_is_up[_srv.local_address] = False except queue.Empty: self.logger.debug( - 'Tunnel to {0} is UP'.format(_srv.remote_address) + 'Tunnel to %s is UP', _srv.remote_address ) self.tunnel_is_up[_srv.local_address] = True finally: @@ -1207,7 +1228,7 @@ def get_agent_keys(logger=None): paramiko_agent = paramiko.Agent() agent_keys = paramiko_agent.get_keys() if logger: - logger.info('{0} keys loaded from agent'.format(len(agent_keys))) + logger.info('%s keys loaded from agent', len(agent_keys)) return list(agent_keys) @staticmethod @@ -1261,10 +1282,13 @@ def get_keys( # noqa: C901 too complex keys.append(ssh_pkey) except OSError as exc: if logger: - logger.warning('Private key file {0} check error: {1}' - .format(ssh_pkey_expanded, exc)) + logger.warning( + 'Private key file %s check error: %s', + ssh_pkey_expanded, + exc + ) if logger: - logger.info('{0} key(s) loaded'.format(len(keys))) + logger.info('%s key(s) loaded', len(keys)) return keys def _get_transport(self): @@ -1274,7 +1298,7 @@ def _get_transport(self): proxy_repr = repr(self.ssh_proxy.cmd[1]) else: proxy_repr = repr(self.ssh_proxy) - self.logger.debug('Connecting via proxy: {0}'.format(proxy_repr)) + self.logger.debug('Connecting via proxy: %s', proxy_repr) _socket = self.ssh_proxy else: _socket = (self.ssh_host, self.ssh_port) @@ -1293,8 +1317,11 @@ def _get_transport(self): if isinstance(sock, socket.socket): sock_timeout = sock.gettimeout() sock_info = repr((sock.family, sock.type, sock.proto)) - self.logger.debug('Transport socket info: {0}, timeout={1}' - .format(sock_info, sock_timeout)) + self.logger.debug( + 'Transport socket info: %s, timeout=%s', + sock_info, + sock_timeout + ) return transport def _check_is_started(self): @@ -1319,11 +1346,12 @@ def _stop_transport(self, force=False): self._transport.stop_thread() for _srv in self._server_list: status = 'up' if self.tunnel_is_up[_srv.local_address] else 'down' - self.logger.info('Shutting down tunnel: {0} <> {1} ({2})'.format( + self.logger.info( + 'Shutting down tunnel: %s <> %s (%s)', address_to_str(_srv.local_address), address_to_str(_srv.remote_address), status - )) + ) _srv.shutdown() _srv.server_close() # clean up the UNIX domain socket if we're using one @@ -1331,8 +1359,11 @@ def _stop_transport(self, force=False): try: os.unlink(_srv.local_address) except OSError as e: - self.logger.error('Unable to unlink socket {0}: {1}' - .format(_srv.local_address, repr(e))) + self.logger.error( + 'Unable to unlink socket %s: %s', + _srv.local_address, + repr(e) + ) self.is_alive = False if self.is_active: self.logger.info('Closing ssh transport') @@ -1348,8 +1379,10 @@ def _connect_to_gateway(self): - As last resort, try with a provided password """ for key in self.ssh_pkeys: - self.logger.debug('Trying to log in with key: {0}' - .format(hexlify(key.get_fingerprint()))) + self.logger.debug( + 'Trying to log in with key: %s', + hexlify(key.get_fingerprint()) + ) try: self._transport = self._get_transport() self._transport.connect(hostkey=self.ssh_host_key, @@ -1362,8 +1395,10 @@ def _connect_to_gateway(self): self._stop_transport() if self.ssh_password: # avoid conflict using both pass and pkey - self.logger.debug('Trying to log in with password: {0}' - .format('*' * len(self.ssh_password))) + self.logger.debug( + 'Trying to log in with password: %s', + '*' * len(self.ssh_password) + ) try: self._transport = self._get_transport() self._transport.connect(hostkey=self.ssh_host_key, @@ -1385,21 +1420,27 @@ def _create_tunnels(self): try: self._connect_to_gateway() except socket.gaierror: # raised by paramiko.Transport - msg = 'Could not resolve IP address for {0}, aborting!' \ - .format(self.ssh_host) - self.logger.error(msg) + self.logger.error( + 'Could not resolve IP address for %s, aborting!', + self.ssh_host + ) return except (paramiko.SSHException, socket.error) as e: - template = 'Could not connect to gateway {0}:{1} : {2}' - msg = template.format(self.ssh_host, self.ssh_port, e.args[0]) - self.logger.error(msg) + self.logger.error( + 'Could not connect to gateway %s:%s : %s', + self.ssh_host, + self.ssh_port, + e.args[0] + ) return for (rem, loc) in zip(self._remote_binds, self._local_binds): try: self._make_ssh_forward_server(rem, loc) except BaseSSHTunnelForwarderError as e: - msg = 'Problem setting SSH Forwarder up: {0}'.format(e.value) - self.logger.error(msg) + self.logger.error( + 'Problem setting SSH Forwarder up: %s', + e.value + ) @staticmethod def read_private_key_file(pkey_file, @@ -1430,34 +1471,43 @@ def read_private_key_file(pkey_file, password=pkey_password ) if logger: - logger.debug('Private key file ({0}, {1}) successfully ' - 'loaded'.format(pkey_file, pkey_class)) + logger.debug( + 'Private key file (%s, %s) successfully loaded', + pkey_file, + pkey_class + ) break except paramiko.PasswordRequiredException: if logger: - logger.error('Password is required for key {0}' - .format(pkey_file)) + logger.error('Password is required for key %s', pkey_file) break except paramiko.SSHException: if logger: - logger.debug('Private key file ({0}) could not be loaded ' - 'as type {1} or bad password' - .format(pkey_file, pkey_class)) + logger.debug( + '%s (%s) %s %s %s', + 'Private key file', + pkey_file, + 'could not be loaded as type', + pkey_class, + 'or bad password' + ) return ssh_pkey def _serve_forever_wrapper(self, _srv, poll_interval=0.1): """ Wrapper for the server created for a SSH forward """ - self.logger.info('Opening tunnel: {0} <> {1}'.format( + self.logger.info( + 'Opening tunnel: %s <> %s', address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) + address_to_str(_srv.remote_address) ) _srv.serve_forever(poll_interval) # blocks until finished - self.logger.info('Tunnel: {0} <> {1} released'.format( + self.logger.info( + 'Tunnel: %s <> %s released', address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) + address_to_str(_srv.remote_address) ) def start(self): @@ -1514,7 +1564,7 @@ def stop(self, force=False): opened_address_text = ', '.join( (address_to_str(k.local_address) for k in self._server_list) ) or 'None' - self.logger.debug('Listening tunnels: ' + opened_address_text) + self.logger.debug('Listening tunnels: %s', opened_address_text) self._stop_transport(force=force) self._server_list = [] # reset server list self.tunnel_is_up = {} # reset tunnel status diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index c6da5c0..76bc9ec 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -140,22 +140,27 @@ def __init__(self, *args, **kwargs): super(NullServer, self).__init__(*args, **kwargs) def check_channel_forward_agent_request(self, channel): - self.log.debug('NullServer.check_channel_forward_agent_request() {0}' - .format(channel)) + self.log.debug( + 'NullServer.check_channel_forward_agent_request() %s', channel + ) return False def get_allowed_auths(self, username): allowed_auths = 'publickey{0}'.format( ',password' if username == SSH_USERNAME else '' ) - self.log.debug('NullServer >> allowed auths for {0}: {1}' - .format(username, allowed_auths)) + self.log.debug( + 'NullServer >> allowed auths for %s: %s', username, allowed_auths + ) return allowed_auths def check_auth_password(self, username, password): _ok = (username == SSH_USERNAME and password == SSH_PASSWORD) - self.log.debug('NullServer >> password for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + self.log.debug( + 'NullServer >> password for %s %sOK', + username, + '' if _ok else 'NOT-' + ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): @@ -167,8 +172,11 @@ def check_auth_publickey(self, username, key): ) except KeyError: _ok = False - self.log.debug('NullServer >> pkey authentication for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + self.log.debug( + 'NullServer >> pkey authentication for %s %sOK', + username, + '' if _ok else 'NOT-' + ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): @@ -188,9 +196,13 @@ def check_global_request(self, kind, msg): return True def check_channel_direct_tcpip_request(self, chanid, origin, destination): - self.log.debug('NullServer.check_channel_direct_tcpip_request' - '(chanid={0}) {1} -> {2}' - .format(chanid, origin, destination)) + self.log.debug( + 'NullServer.%s (chanid=%s) %s -> %s', + 'check_channel_direct_tcpip_request', + chanid, + origin, + destination + ) return paramiko.OPEN_SUCCEEDED @@ -223,13 +235,11 @@ def setUpClass(cls): def setUp(self): super(SSHClientTest, self).setUp() self.log.debug('*' * 80) - self.log.info('setUp for: {0}()'.format(self._testMethodName.upper())) + self.log.info('setUp for: %s()', self._testMethodName.upper()) self.ssockl, self.saddr, self.sport = self.make_socket() self.esockl, self.eaddr, self.eport = self.make_socket() - self.log.info("Socket for ssh-server: {0}:{1}" - .format(self.saddr, self.sport)) - self.log.info("Socket for echo-server: {0}:{1}" - .format(self.eaddr, self.eport)) + self.log.info("Socket for ssh-server: %s:%s", self.saddr, self.sport) + self.log.info("Socket for echo-server: %s:%s", self.eaddr, self.eport) self.ssh_event = threading.Event() self.running_threads = [] @@ -239,14 +249,15 @@ def setUp(self): self._sshtunnel_log_handler.reset() def tearDown(self): - self.log.info('tearDown for: {0}()' - .format(self._testMethodName.upper())) + self.log.info('tearDown for: %s()', self._testMethodName.upper()) self.stop_echo_and_ssh_server() for thread in self.running_threads: x = self.threads[thread] - self.log.info('thread {0} ({1})' - .format(thread, - 'alive' if x.is_alive() else 'defunct')) + self.log.info( + 'thread %s (%s)', + thread, + 'alive' if x.is_alive() else 'defunct' + ) while self.running_threads: for thread in self.running_threads: @@ -254,18 +265,18 @@ def tearDown(self): self.wait_for_thread(self.threads[thread], who='tearDown') if not x.is_alive(): - self.log.info('thread {0} now stopped'.format(thread)) + self.log.info('thread %s now stopped', thread) for attr in ['server', 'tc', 'ts', 'socks', 'ssockl', 'esockl']: if hasattr(self, attr): - self.log.info('tearDown() {0}'.format(attr)) + self.log.info('tearDown() %s', attr) getattr(self, attr).close() def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): if thread.is_alive(): - self.log.debug('{0}waiting for {1} to end...' - .format('{0} '.format(who) if who else '', - thread.name)) + self.log.debug( + '%s waiting for %s to end...', who or '', thread.name + ) thread.join(timeout) def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): @@ -274,7 +285,7 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): try: schan = self.ts.accept(timeout=timeout) info = "forward-server schan <> echo" - self.log.info(info + " accept()") + self.log.info("%s accept()", info) echo = socket.create_connection( (self.eaddr, self.eport) ) @@ -285,28 +296,28 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): timeout) if schan in rqst: data = schan.recv(1024) - self.log.debug('{0} -->: {1}'.format(info, repr(data))) + self.log.debug('%s -->: %s', info, repr(data)) echo.send(data) if len(data) == 0: break if echo in rqst: data = echo.recv(1024) - self.log.debug('{0} <--: {1}'.format(info, repr(data))) + self.log.debug('%s <--: %s', info, repr(data)) schan.send(data) if len(data) == 0: break self.log.info('<<< forward-server received STOP signal') except socket.error: - self.log.critical('{0} sending RST'.format(info)) + self.log.critical('%s sending RST', info) # except Exception as e: # # we reach this point usually when schan is None (paramiko bug?) # self.log.critical(repr(e)) finally: if schan: - self.log.debug('{0} closing connection...'.format(info)) + self.log.debug('%s closing connection...', info) schan.close() echo.close() - self.log.debug('{0} connection closed.'.format(info)) + self.log.debug('%s connection closed.', info) def _run_ssh_server(self): self.log.info('ssh-server Start') @@ -383,8 +394,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): # handle the server socket try: client, address = self.esockl.accept() - self.log.info('echo-server accept() {0}' - .format(address)) + self.log.info('echo-server accept() %s', address) except OSError: self.log.info('echo-server accept() OSError') break @@ -393,8 +403,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): # handle all other sockets try: data = s.recv(1000) - self.log.info('echo-server echoing {0}' - .format(data)) + self.log.info('echo-server echoing %s', data) s.send(data) except OSError: self.log.warning('echo-server OSError') @@ -404,7 +413,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): socks.remove(s) self.log.info('<<< echo-server received STOP signal') except AttributeError as e: - self.log.info('echo-server got Exception: {0}'.format(repr(e))) + self.log.info('echo-server got Exception: %s', repr(e)) finally: self.is_server_working = False if 'forward-server' in self.threads: @@ -439,8 +448,10 @@ def test_echo_server(self): local_bind_addr = ('127.0.0.1', server.local_bind_port) self.log.info('_test_server(): try connect!') s = socket.create_connection(local_bind_addr) - self.log.info('_test_server(): connected from {0}! try send!' - .format(s.getsockname())) + self.log.info( + '_test_server(): connected from %s! try send!', + s.getsockname() + ) s.send(message) self.log.info('_test_server(): sent!') z = (s.recv(1000)) From 3d8da6386d9643e39f7c4869ce2d352f5b50fca5 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:13:22 -0500 Subject: [PATCH 19/46] remove unnecessary pass statement --- sshtunnel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 6addc53..c5662e0 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -299,7 +299,6 @@ def __str__(self): class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): """ Exception for Tunnel forwarder errors """ - pass ######################## From 57bf41a9ca8e9e28ae9d777e819e18480e780b82 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:15:15 -0500 Subject: [PATCH 20/46] Remove unnecessary `elif` statements --- e2e_tests/run_docker_e2e_db_tests.py | 4 ++-- sshtunnel.py | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index a2ca6ab..ee38004 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -58,7 +58,7 @@ def wait(conn): state = conn.poll() if state == psycopg2.extensions.POLL_OK: break - elif state == psycopg2.extensions.POLL_WRITE: + if state == psycopg2.extensions.POLL_WRITE: select.select([], [conn.fileno()], []) elif state == psycopg2.extensions.POLL_READ: select.select([conn.fileno()], [], []) @@ -71,7 +71,7 @@ def wait_timeout(conn): state = conn.poll() if state == psycopg2.extensions.POLL_OK: return ASYNC_OK - elif state == psycopg2.extensions.POLL_WRITE: + if state == psycopg2.extensions.POLL_WRITE: # Wait for the given time and then check the return status # If three empty lists are returned then the time-out is # reached. diff --git a/sshtunnel.py b/sshtunnel.py index c5662e0..7c6c67d 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -910,9 +910,8 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): " argument".format(addr_kind) ) raise ValueError(msg) - else: - return [] - elif bind_address and bind_addresses: + return [] + if bind_address and bind_addresses: msg = ( "You can't use both '{0}_bind_address' and " "'{0}_bind_addresses' arguments. Use one of " @@ -955,8 +954,7 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): ) ) raise ValueError(msg) - else: - return kwargs.pop(deprecated_attrib) + return kwargs.pop(deprecated_attrib) return attrib def __init__( @@ -1173,8 +1171,7 @@ def _make_stream_ssh_forward_server_class(self, remote_address_): def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): if self._raise_fwd_exc: raise exception(reason) - else: - self.logger.error(repr(exception(reason))) + self.logger.error(repr(exception(reason))) def _make_ssh_forward_server(self, remote_address, local_bind_address): """ @@ -1800,7 +1797,7 @@ def _bindlist(input_str): (_ip, _port) = ip_port if not _ip and not _port: raise AssertionError - elif not _port: + if not _port: _port = '22' # default port if not given return _ip, int(_port) except ValueError: From 6e0271e1660de3cdd7831da53bd59fb52e5cbffc Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:17:01 -0500 Subject: [PATCH 21/46] Remove unnecessary assignment --- e2e_tests/run_docker_e2e_db_tests.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index ee38004..e5d441a 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -138,7 +138,7 @@ def run_mongo_query(port, query=MONGO_QUERY): def create_tunnel(): logging.info('Creating SSHTunnelForwarder... (sshtunnel v%s, paramiko v%s)', sshtunnel.__version__, paramiko.__version__) - tunnel = SSHTunnelForwarder( + return SSHTunnelForwarder( SSH_SERVER_ADDRESS, ssh_username=SSH_SERVER_USERNAME, ssh_pkey=SSH_PKEY, @@ -148,7 +148,6 @@ def create_tunnel(): ], logger=logger, ) - return tunnel def start(tunnel): From 3d9d8e272f7eb636b59b1b9bb2c1663ea4a08ccf Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:18:49 -0500 Subject: [PATCH 22/46] Add explicit `return` statement --- e2e_tests/run_docker_e2e_db_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index e5d441a..5ef688a 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -93,6 +93,7 @@ def wait_timeout(conn): raise psycopg2.OperationalError( "poll() returned %s from _wait_timeout function" % state ) + return None pg_conn = psycopg2.connect( host='127.0.0.1', From 10dd623a537094ad0c7a2ab1ecc75c32463fee2c Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:19:53 -0500 Subject: [PATCH 23/46] don't supply None as default to kwargs.get --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 7c6c67d..d640d01 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1753,7 +1753,7 @@ def do_something(port): """ # Attach a console handler to the logger or create one if not passed loglevel = kwargs.pop('debug_level', None) - logger = kwargs.get('logger', None) or create_logger(loglevel=loglevel) + logger = kwargs.get('logger') or create_logger(loglevel=loglevel) kwargs['logger'] = logger ssh_address_or_host = kwargs.pop('ssh_address_or_host', None) From 2278b32f09843966559d99459ae50a8922088f4f Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:20:45 -0500 Subject: [PATCH 24/46] Return condition directly as bool --- sshtunnel.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index d640d01..02edf5f 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1645,12 +1645,10 @@ def tunnel_bindings(self): @property def is_active(self): """ Return True if the underlying SSH transport is up """ - if ( + return bool( '_transport' in self.__dict__ and self._transport.is_active() - ): - return True - return False + ) def __exit__(self, *args): self.stop(force=True) From 17c4ea1c80a8867a3e15d7def2755e8c95f77d3d Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:21:08 -0500 Subject: [PATCH 25/46] Use `key in dict` instead of `key in dict.keys()` --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 02edf5f..6ea7a1b 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1263,7 +1263,7 @@ def get_keys( # noqa: C901 too complex if hasattr(paramiko, 'Ed25519Key'): paramiko_key_types['ed25519'] = paramiko.Ed25519Key for directory in host_pkey_directories: - for keytype in paramiko_key_types.keys(): + for keytype in paramiko_key_types: ssh_pkey_expanded = os.path.expanduser( os.path.join(directory, 'id_{}'.format(keytype)) ) From 7ce54f163cb01a449894c7332ffc6ba40ab0f76f Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:21:52 -0500 Subject: [PATCH 26/46] Remove unnecessary `True if ... else False` --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 6ea7a1b..010139d 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -822,7 +822,7 @@ def _read_ssh_config(ssh_host, proxycommand else None) if compression is None: compression = hostname_info.get('compression', '') - compression = True if compression.upper() == 'YES' else False + compression = compression.upper() == 'YES' except IOError: if logger: logger.warning( From 0b61e8f87094f7a3a6b82e6fa4f8b3ff1c5540c8 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:24:22 -0500 Subject: [PATCH 27/46] Don't extract value from dictionary without calling `.items()` --- sshtunnel.py | 4 ++-- tests/test_forwarder.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 010139d..dbd39e8 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1263,7 +1263,7 @@ def get_keys( # noqa: C901 too complex if hasattr(paramiko, 'Ed25519Key'): paramiko_key_types['ed25519'] = paramiko.Ed25519Key for directory in host_pkey_directories: - for keytype in paramiko_key_types: + for keytype, value in paramiko_key_types.items(): ssh_pkey_expanded = os.path.expanduser( os.path.join(directory, 'id_{}'.format(keytype)) ) @@ -1272,7 +1272,7 @@ def get_keys( # noqa: C901 too complex ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, logger=logger, - key_type=paramiko_key_types[keytype] + key_type=value ) if ssh_pkey: keys.append(ssh_pkey) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 76bc9ec..5edde38 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1382,13 +1382,15 @@ def test_process_deprecations(self): 'ssh_address': '10.0.0.1', 'ssh_private_key': 'testrsa.key', 'raise_exception_if_any_forwarder_have_a_problem': True} - for item in kwargs: - self.assertEqual(kwargs[item], - sshtunnel.SSHTunnelForwarder._process_deprecated( - None, - item, - kwargs.copy() - )) + for item, value in kwargs.items(): + self.assertEqual( + value, + sshtunnel.SSHTunnelForwarder._process_deprecated( + None, + item, + kwargs.copy() + ) + ) # use both deprecated and not None new attribute should raise exception for item in kwargs: with self.assertRaises(ValueError): From 6a0ed8c85171ce6655d8590f1c0696a91a8883ff Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:26:09 -0500 Subject: [PATCH 28/46] Unused unpacked variables should assign to _ --- sshtunnel.py | 2 +- tests/test_forwarder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index dbd39e8..3edd56c 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -416,7 +416,7 @@ def __init__(self, *args, **kwargs): socketserver.TCPServer.__init__(self, *args, **kwargs) def handle_error(self, request, client_address): - (exc_class, exc, tb) = sys.exc_info() + (_, exc, _) = sys.exc_info() local_side = request.getsockname() remote_side = self.remote_address self.logger.error( diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 5edde38..877892a 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -322,7 +322,7 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): def _run_ssh_server(self): self.log.info('ssh-server Start') try: - self.socks, addr = self.ssockl.accept() + self.socks, _ = self.ssockl.accept() except socket.timeout: self.log.error('ssh-server connection timed out!') self.running_threads.remove('ssh-server') From 59bbf203f472464aac6a49116831ac99b01733bf Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 00:27:00 -0500 Subject: [PATCH 29/46] Unused `noqa` directive --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 45fe2ae..3499d54 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ # The project's main homepage. url=url, - download_url=ppa + version + '.zip', # noqa + download_url=ppa + version + '.zip', # Author details author='Pahaz White', From fe880d5245ba85dcbe81a1114747502a906e97bc Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 18:45:30 -0500 Subject: [PATCH 30/46] Replace ternary `if` expression with `or` operator --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 3edd56c..fa0a57f 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1690,7 +1690,7 @@ def __str__(self): self.ssh_proxy.cmd[1] if self.ssh_proxy else 'no', self.ssh_username, credentials, - self.ssh_host_key if self.ssh_host_key else 'not checked', + self.ssh_host_key or 'not checked', '' if self.is_alive else 'not ', 'disabled' if not self.set_keepalive else 'every {0} sec'.format(self.set_keepalive), From 99ea248124db089f93c15e800b2e9e2bcf8b6f2d Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 18:50:13 -0500 Subject: [PATCH 31/46] resolve rst docstring issues --- sshtunnel.py | 278 ++++++++++++++++++++++++---------------- tests/test_forwarder.py | 19 ++- 2 files changed, 178 insertions(+), 119 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index fa0a57f..5f495bc 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -96,20 +96,26 @@ def check_address(address): """ Check if the format of the address is correct + .. code-block:: python + + check_address(('127.0.0.1', 22)) + Arguments: - address (tuple): - (``str``, ``int``) representing an IP address and port, + address (tuple, str): + (``str``, ``int``) or ``str`` representing an IP address and port, respectively .. note:: - alternatively a local ``address`` can be a ``str`` when working - with UNIX domain sockets, if supported by the platform + a local ``address`` can be a ``str`` only when working + with UNIX domain sockets (not supported on all platforms) + Raises: ValueError: - raised when address has an incorrect format + When ``address`` has an incorrect format + + TypeError: + When ``address`` is not a tuple, string, or character buffer - Example: - >>> check_address(('127.0.0.1', 22)) """ if isinstance(address, tuple): check_host(address[0]) @@ -139,6 +145,10 @@ def check_addresses(address_list, is_remote=False): """ Check if the format of the addresses is correct + .. code-block:: python + + check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)]) + Arguments: address_list (list[tuple]): Sequence of (``str``, ``int``) pairs, each representing an IP @@ -149,17 +159,13 @@ def check_addresses(address_list, is_remote=False): the list can be of type ``str``, representing a valid UNIX domain socket - is_remote (boolean): + is_remote (bool): Whether or not the address list + Raises: AssertionError: - raised when ``address_list`` contains an invalid element - ValueError: - raised when any address in the list has an incorrect format - - Example: + When ``address_list`` contains an invalid element - >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)]) """ assert all(isinstance(x, (tuple, string_types)) for x in address_list) if (is_remote and any(isinstance(x, string_types) for x in address_list)): @@ -213,32 +219,33 @@ def create_logger(logger=None, Attach or create a new logger and add a console handler if not present Arguments: - logger (Optional[logging.Logger]): - :class:`logging.Logger` instance; a new one is created if this - argument is empty + :py:class:`logging.Logger` instance; + a new one is created if this argument is empty loglevel (Optional[str or int]): - :class:`logging.Logger`'s level, either as a string (i.e. - ``ERROR``) or in numeric format (10 == ``DEBUG``) + :py:class:`logging.Logger`'s level, + either as a string (i.e. ``ERROR``) + or in numeric format (10 == ``DEBUG``) - .. note:: a value of 1 == ``TRACE`` enables Tracing mode + .. note:: + a value of 1 == ``TRACE`` enables Tracing mode - capture_warnings (boolean): + capture_warnings (bool): Enable/disable capturing the events logged by the warnings module into ``logger``'s handlers Default: True - .. note:: ignored in python 2.6 - - add_paramiko_handler (boolean): + add_paramiko_handler (bool): Whether or not add a console handler for ``paramiko.transport``'s logger if no handler present Default: True + Return: - :class:`logging.Logger` + :py:class:`logging.Logger` + """ logger = logger or logging.getLogger( 'sshtunnel.SSHTunnelForwarder' @@ -288,7 +295,7 @@ def generate_random_string(length): class BaseSSHTunnelForwarderError(Exception): - """ Exception raised by :class:`SSHTunnelForwarder` errors """ + """ Exception raised by :py:class:`SSHTunnelForwarder` errors """ def __init__(self, *args, **kwargs): self.value = kwargs.pop('value', args[0] if args else '') @@ -520,29 +527,28 @@ class SSHTunnelForwarder(object): """ **SSH tunnel class** - - Initialize a SSH tunnel to a remote host according to the input - arguments + - Initialize a SSH tunnel to a remote host according to the input + arguments - - Optionally: - + Read an SSH configuration file (typically ``~/.ssh/config``) - + Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring) + - Optionally: - Raises: + + Read an SSH configuration file (typically ``~/.ssh/config``) + + Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring) - :class:`.BaseSSHTunnelForwarderError`: + Raises: + :py:class:`.BaseSSHTunnelForwarderError`: raised by SSHTunnelForwarder class methods - :class:`.HandlerSSHTunnelForwarderError`: + :py:class:`.HandlerSSHTunnelForwarderError`: raised by tunnel forwarder threads .. note:: - Attributes ``mute_exceptions`` and - ``raise_exception_if_any_forwarder_have_a_problem`` - (deprecated) may be used to silence most exceptions raised - from this class + Attributes ``mute_exceptions`` and + ``raise_exception_if_any_forwarder_have_a_problem`` + (deprecated) may be used to silence most exceptions raised + from this class Keyword Arguments: - ssh_address_or_host (tuple or str): IP or hostname of ``REMOTE GATEWAY``. It may be a two-element tuple (``str``, ``int``) representing IP and port respectively, @@ -589,7 +595,7 @@ class SSHTunnelForwarder(object): ssh_pkey (str or paramiko.PKey): **Private** key file name (``str``) to obtain the public key - from or a **public** key (:class:`paramiko.pkey.PKey`) + from or a **public** key (:py:class:`paramiko.pkey.PKey`) ssh_private_key_password (str): Password for an encrypted ``ssh_pkey`` @@ -600,9 +606,9 @@ class SSHTunnelForwarder(object): ssh_proxy (socket-like object or tuple): Proxy where all SSH traffic will be passed through. - It might be for example a :class:`paramiko.proxy.ProxyCommand` + It might be for example a :py:class:`paramiko.proxy.ProxyCommand` instance. - See either the :class:`paramiko.transport.Transport`'s sock + See either the :py:class:`paramiko.transport.Transport`'s sock parameter documentation or ``ProxyCommand`` in ``ssh_config(5)`` for more information. @@ -614,11 +620,11 @@ class SSHTunnelForwarder(object): .. versionadded:: 0.0.5 - ssh_proxy_enabled (boolean): + ssh_proxy_enabled (bool): Enable/disable SSH proxy. If True and user's ``ssh_config_file`` contains a ``ProxyCommand`` directive that matches the specified ``ssh_address_or_host``, - a :class:`paramiko.proxy.ProxyCommand` object will be created where + a :py:class:`paramiko.proxy.ProxyCommand` object will be created where all SSH traffic will be passed through Default: ``True`` @@ -659,7 +665,7 @@ class SSHTunnelForwarder(object): .. versionadded:: 0.0.4 - allow_agent (boolean): + allow_agent (bool): Enable/disable load of keys from an SSH agent Default: ``True`` @@ -673,7 +679,7 @@ class SSHTunnelForwarder(object): .. versionadded:: 0.1.4 - compression (boolean): + compression (bool): Turn on/off transport compression. By default compression is disabled since it may negatively affect interactive sessions @@ -684,15 +690,15 @@ class SSHTunnelForwarder(object): logger (logging.Logger): logging instance for sshtunnel and paramiko - Default: :class:`logging.Logger` instance with a single - :class:`logging.StreamHandler` handler and - :const:`DEFAULT_LOGLEVEL` level + Default: :py:class:`logging.Logger` instance with a single + :py:class:`logging.StreamHandler` handler and + py:const: `DEFAULT_LOGLEVEL` level .. versionadded:: 0.0.3 - mute_exceptions (boolean): - Allow silencing :class:`BaseSSHTunnelForwarderError` or - :class:`HandlerSSHTunnelForwarderError` exceptions when enabled + mute_exceptions (bool): + Allow silencing :py:class:`BaseSSHTunnelForwarderError` or + :py:class:`HandlerSSHTunnelForwarderError` exceptions when enabled Default: ``False`` @@ -709,7 +715,7 @@ class SSHTunnelForwarder(object): .. versionadded:: 0.0.7 - threaded (boolean): + threaded (bool): Allow concurrent connections over a single tunnel Default: ``True`` @@ -731,13 +737,13 @@ class SSHTunnelForwarder(object): ssh_private_key (str or paramiko.PKey): Superseded by ``ssh_pkey``, which can represent either a **private** key file name (``str``) or a **public** key - (:class:`paramiko.pkey.PKey`) + (:py:class:`paramiko.pkey.PKey`) .. deprecated:: 0.0.8 - raise_exception_if_any_forwarder_have_a_problem (boolean): - Allow silencing :class:`BaseSSHTunnelForwarderError` or - :class:`HandlerSSHTunnelForwarderError` exceptions when set to + raise_exception_if_any_forwarder_have_a_problem (bool): + Allow silencing :py:class:`BaseSSHTunnelForwarderError` or + :py:class:`HandlerSSHTunnelForwarderError` exceptions when set to False Default: ``True`` @@ -760,14 +766,15 @@ class SSHTunnelForwarder(object): When :attr:`.skip_tunnel_checkup` is disabled or the local bind is a UNIX socket, the value will always be ``True`` - **Example**:: - - {('127.0.0.1', 55550): True, # this tunnel is up - ('127.0.0.1', 55551): False} # this one isn't + .. code-block:: python + :caption: where 55550 and 55551 are the local bind ports - where 55550 and 55551 are the local bind ports + { + ('127.0.0.1', 55550): True, # this tunnel is up + ('127.0.0.1', 55551): False # this one isn't + } - skip_tunnel_checkup (boolean): + skip_tunnel_checkup (bool): Disable tunnel checkup (default for backwards compatibility). .. versionadded:: 0.1.0 @@ -845,6 +852,11 @@ def _consolidate_binds(local_binds, remote_binds): """ Fill local_binds with defaults when no value/s were specified, leaving paramiko to decide in which local port the tunnel will be open + + Raises: + ValueError: + When there are more local bind addresses than remote addresses + """ count = len(remote_binds) - len(local_binds) if count < 0: @@ -865,10 +877,16 @@ def _consolidate_auth(ssh_password=None, logger=None): """ Get sure authentication information is in place. + ``ssh_pkey`` may be of classes: - - ``str`` - in this case it represents a private key file; public - key will be obtained from it - - ``paramiko.Pkey`` - it will be transparently added to loaded keys + + - If ``str``, it represents a private key file; + public key will be obtained from it + - If ``paramiko.Pkey``, it will be transparently added to loaded keys + + Raises: + ValueError: + When no password or public key are provided or available """ ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( @@ -932,6 +950,12 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): def _process_deprecated(attrib, deprecated_attrib, kwargs): """ Processes optional deprecate arguments + + Raises: + ValueError: + When a pre-deprecation arg AND its + replacement are both provided + """ if deprecated_attrib not in _DEPRECATIONS: msg = ( @@ -1079,15 +1103,17 @@ def local_is_up(self, target): Check if a tunnel is up (remote target's host is reachable on TCP target's port) + .. deprecated:: 0.1.0 + Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` + Arguments: target (tuple): tuple of type (``str``, ``int``) indicating the listen IP address and port - Return: + + Returns: boolean - .. deprecated:: 0.1.0 - Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` """ try: check_address(target) @@ -1218,8 +1244,9 @@ def get_agent_keys(logger=None): Arguments: logger (Optional[logging.Logger]) - Return: + Returns: list + """ paramiko_agent = paramiko.Agent() agent_keys = paramiko_agent.get_keys() @@ -1248,8 +1275,9 @@ def get_keys( # noqa: C901 too complex Default: False - Return: + Returns: list + """ keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \ if allow_agent else [] @@ -1370,9 +1398,11 @@ def _stop_transport(self, force=False): def _connect_to_gateway(self): """ Open connection to SSH gateway - - First try with all keys loaded from an SSH agent (if allowed) - - Then with those passed directly or read from ~/.ssh/config - - As last resort, try with a provided password + + - First try with all keys loaded from an SSH agent (if allowed) + - Then with those passed directly or read from ~/.ssh/config + - As last resort, try with a provided password + """ for key in self.ssh_pkeys: self.logger.debug( @@ -1449,12 +1479,15 @@ def read_private_key_file(pkey_file, Arguments: pkey_file (str): File containing a private key (RSA, DSS or ECDSA) + Keyword Arguments: pkey_password (Optional[str]): Password to decrypt the private key logger (Optional[logging.Logger]) + Return: paramiko.Pkey + """ ssh_pkey = None key_types = (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey) @@ -1534,27 +1567,29 @@ def stop(self, force=False): Shut the tunnel down. By default we are always waiting until closing all connections. You can use `force=True` to force close connections - Keyword Arguments: - force (bool): - Force close current connections - - Default: False - - .. versionadded:: 0.2.2 - - .. note:: This **had** to be handled with care before ``0.1.0``: + .. note:: + This **had** to be handled with care before ``0.1.0`` - if a port redirection is opened - the destination is not reachable - we attempt a connection to that tunnel (``SYN`` is sent and acknowledged, then a ``FIN`` packet is sent and never acknowledged... weird) - - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` and - ``CLOSE_WAIT`` time out. + - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` + and ``CLOSE_WAIT`` time out. .. note:: - Handle these scenarios with :attr:`.tunnel_is_up`: if False, server - ``shutdown()`` will be skipped on that tunnel + Handle these scenarios with :attr:`.tunnel_is_up`: + if False, server ``shutdown()`` will be skipped on that tunnel + + Keyword Arguments: + force (bool): + Force close current connections + + Default: False + + .. versionadded:: 0.2.2 + """ self.logger.info('Closing all open connections...') opened_address_text = ', '.join( @@ -1708,7 +1743,32 @@ def __repr__(self): def open_tunnel(*args, **kwargs): """ - Open an SSH Tunnel, wrapper for :class:`SSHTunnelForwarder` + Open an SSH Tunnel, wrapper for :py:class:`SSHTunnelForwarder` + + .. note:: + A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode + + .. note:: + See :py:class:`SSHTunnelForwarder` for keyword arguments + + .. code-block:: python + + from sshtunnel import open_tunnel + + with open_tunnel( + SERVER, + ssh_username=SSH_USER, + ssh_port=22, + ssh_password=SSH_PASSWORD, + remote_bind_address=(REMOTE_HOST, REMOTE_PORT), + local_bind_address=('', LOCAL_PORT) + ) as server: + + def do_something(port): + pass + + print("LOCAL PORTS:", server.local_bind_port) + do_something(server.local_bind_port) Arguments: destination (Optional[tuple]): @@ -1717,9 +1777,10 @@ def open_tunnel(*args, **kwargs): Keyword Arguments: debug_level (Optional[int or str]): - log level for :class:`logging.Logger` instance, i.e. ``DEBUG`` + log level for :py:class:`logging.Logger` instance, + i.e. ``DEBUG`` - skip_tunnel_checkup (boolean): + skip_tunnel_checkup (bool): Enable/disable the local side check and populate :attr:`~SSHTunnelForwarder.tunnel_is_up` @@ -1727,27 +1788,6 @@ def open_tunnel(*args, **kwargs): .. versionadded:: 0.1.0 - .. note:: - A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode - .. note:: - See :class:`SSHTunnelForwarder` for keyword arguments - - **Example**:: - - from sshtunnel import open_tunnel - - with open_tunnel(SERVER, - ssh_username=SSH_USER, - ssh_port=22, - ssh_password=SSH_PASSWORD, - remote_bind_address=(REMOTE_HOST, REMOTE_PORT), - local_bind_address=('', LOCAL_PORT)) as server: - def do_something(port): - pass - - print("LOCAL PORTS:", server.local_bind_port) - - do_something(server.local_bind_port) """ # Attach a console handler to the logger or create one if not passed loglevel = kwargs.pop('debug_level', None) @@ -1783,8 +1823,20 @@ def do_something(port): def _bindlist(input_str): - """ Define type of data expected for remote and local bind address lists - Returns a tuple (ip_address, port) whose elements are (str, int) + """ + Define type of data expected for remote and local bind address lists + + Returns: + tuple + (ip_address, port) whose elements are (str, int) + + Raises: + ArgumentTypeError: + When tuple is not IP_ADDRESS:PORT + + AssertionError: + When IP_ADDRFESS and/or PORT are missing + """ try: ip_port = input_str.split(':') diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 877892a..ffb8c7a 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -36,12 +36,19 @@ def get_random_string(length=12): """ - >>> r = get_random_string(1) - >>> r in asciis - True - >>> r = get_random_string(2) - >>> [r[0] in asciis, r[1] in asciis] - [True, True] + + .. code-block:: python + :caption: output is True + + r = get_random_string(1) + r in asciis + + .. code-block:: python + :caption: output is [True, True] + + r = get_random_string(2) + [r[0] in asciis, r[1] in asciis] + """ ascii_lowercase = 'abcdefghijklmnopqrstuvwxyz' ascii_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' From ce08d2d396294697c76501b929d599d33d5bbb7a Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:37:59 -0500 Subject: [PATCH 32/46] don't define all allowed characters in get_random_string --- tests/test_forwarder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index ffb8c7a..ab19649 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -8,6 +8,7 @@ import select import shutil import socket +import string import sys import tempfile import threading @@ -50,10 +51,7 @@ def get_random_string(length=12): [r[0] in asciis, r[1] in asciis] """ - ascii_lowercase = 'abcdefghijklmnopqrstuvwxyz' - ascii_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' - digits = '0123456789' - asciis = ascii_lowercase + ascii_uppercase + digits + asciis = string.ascii_lowercase + string.ascii_uppercase + string.digits return ''.join([random.choice(asciis) for _ in range(length)]) From cfabba4a015f21c0f99ab42d92806e0b4b4fe6e3 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:45:59 -0500 Subject: [PATCH 33/46] Replace aliased errors with `OSError` --- sshtunnel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 5f495bc..646ad1f 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -374,7 +374,7 @@ def handle(self): src_addr=src_address, timeout=TUNNEL_TIMEOUT ) - except (paramiko.SSHException, EnvironmentError) as e: + except (OSError, paramiko.SSHException) as e: type_msg = 'ssh ' if isinstance(e, paramiko.SSHException) else '' exc_msg = 'open new channel {0}error: {1}'.format(type_msg, e) self.logger.log(TRACE_LEVEL, '%s %s', self.info, exc_msg) @@ -387,7 +387,7 @@ def handle(self): ) try: self._redirect(chan) - except socket.error: + except OSError: # Sometimes a RST is sent and a socket error is raised, treat this # exception. It was seen that a 3way FIN is processed later on, so # no need to make an ordered close of the connection here or raise @@ -830,7 +830,7 @@ def _read_ssh_config(ssh_host, if compression is None: compression = hostname_info.get('compression', '') compression = compression.upper() == 'YES' - except IOError: + except OSError: if logger: logger.warning( 'Could not read SSH configuration file: %s', @@ -1149,7 +1149,7 @@ def _check_tunnel(self, _srv): self.logger.debug( 'Tunnel to %s is DOWN', _srv.remote_address ) - except socket.error: + except OSError: self.logger.debug( 'Tunnel to %s is DOWN', _srv.remote_address ) @@ -1227,7 +1227,7 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): 'argument'.format(address_to_str(local_bind_address), address_to_str(remote_address)) ) - except IOError: + except OSError: self._raise( BaseSSHTunnelForwarderError, "Couldn't open tunnel {0} <> {1} might be in use or " @@ -1451,7 +1451,7 @@ def _create_tunnels(self): self.ssh_host ) return - except (paramiko.SSHException, socket.error) as e: + except (OSError, paramiko.SSHException) as e: self.logger.error( 'Could not connect to gateway %s:%s : %s', self.ssh_host, From bd6de01d86f4ba760aa91de7d556a2978f149ea5 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:47:41 -0500 Subject: [PATCH 34/46] Remove `object` inheritance --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 646ad1f..8ebecfa 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -523,7 +523,7 @@ class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn, daemon_threads = _DAEMON -class SSHTunnelForwarder(object): +class SSHTunnelForwarder: """ **SSH tunnel class** From fea446484705d5c743360c560dcdcb3fc28752f1 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:49:36 -0500 Subject: [PATCH 35/46] remove unnecessary coding: utf-8 --- docs/conf.py | 1 - sshtunnel.py | 1 - 2 files changed, 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 64eb688..85cc6fb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # # sshtunnel documentation build configuration file, created by # sphinx-quickstart on Mon Feb 22 11:01:56 2016. diff --git a/sshtunnel.py b/sshtunnel.py index 8ebecfa..1754f32 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ *sshtunnel* - Initiate SSH tunnels via a remote gateway. From c0f278c48731b98b4dfe80fa2d5b6e6d3c60b162 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:53:56 -0500 Subject: [PATCH 36/46] shorten long line --- sshtunnel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 1754f32..100554e 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -623,8 +623,8 @@ class SSHTunnelForwarder: Enable/disable SSH proxy. If True and user's ``ssh_config_file`` contains a ``ProxyCommand`` directive that matches the specified ``ssh_address_or_host``, - a :py:class:`paramiko.proxy.ProxyCommand` object will be created where - all SSH traffic will be passed through + a :py:class:`paramiko.proxy.ProxyCommand` object will be created + where all SSH traffic will be passed through Default: ``True`` From 088bbd252b8b2666e2c6ef2c41514335d52aa67f Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:55:51 -0500 Subject: [PATCH 37/46] remove extraneous parentheses --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 100554e..4f509bf 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1592,7 +1592,7 @@ def stop(self, force=False): """ self.logger.info('Closing all open connections...') opened_address_text = ', '.join( - (address_to_str(k.local_address) for k in self._server_list) + address_to_str(k.local_address) for k in self._server_list ) or 'None' self.logger.debug('Listening tunnels: %s', opened_address_text) self._stop_transport(force=force) From f82315f1bfdaf36c95afd79eb47377b278ec1783 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 19:58:48 -0500 Subject: [PATCH 38/46] use dict comprehension instead of generator --- sshtunnel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 4f509bf..686f37a 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1672,9 +1672,11 @@ def tunnel_bindings(self): """ Return a dictionary containing the active local<>remote tunnel_bindings """ - return dict((_server.remote_address, _server.local_address) for - _server in self._server_list if - self.tunnel_is_up[_server.local_address]) + return { + _server.remote_address: _server.local_address + for _server in self._server_list + if self.tunnel_is_up[_server.local_address] + } @property def is_active(self): From de5262fe7d8a92582ca84a80caf5f333a58bacc6 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 20:02:02 -0500 Subject: [PATCH 39/46] move return statement from try to else --- sshtunnel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index 686f37a..21b5a55 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1692,10 +1692,11 @@ def __exit__(self, *args): def __enter__(self): try: self.start() - return self except KeyboardInterrupt: self.__exit__() raise + else: + return self def __str__(self): credentials = { From 489d833be4eddb9f848d36eca28e9008ad6b4009 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 20:06:51 -0500 Subject: [PATCH 40/46] Use format specifiers instead of percent format --- docs/conf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 85cc6fb..b33942f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,8 @@ def _warn_node(self, msg, node): if not msg.startswith('nonlocal image URI found:'): - self._warnfunc(msg, '%s:%s' % get_source_line(node)) + line, col = get_source_line(node) + self._warnfunc(msg, "{0}:{1}".format(line, col)) sphinx.environment.BuildEnvironment.warn_node = _warn_node From ab88e064d3a9b2dfc332621fe39f2e1ce71d36b5 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 20:12:32 -0500 Subject: [PATCH 41/46] when checking for version, check for 3 before 2 as it's more common --- sshtunnel.py | 12 ++++++------ tests/test_forwarder.py | 11 ++++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 21b5a55..adb771c 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -29,17 +29,17 @@ import paramiko -if sys.version_info[0] < 3: +if sys.version_info[0] >= 3: + import queue + import socketserver + string_types = str + input_ = input +else: import Queue as queue import SocketServer as socketserver string_types = basestring # noqa: F821 undefined name input_ = raw_input # noqa: F821 undefined name -else: - import queue - import socketserver - string_types = str - input_ = input __version__ = '0.4.0' diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index ab19649..1f686b8 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -22,15 +22,12 @@ import sshtunnel -if sys.version_info[0] == 2: - from cStringIO import StringIO - if sys.version_info < (2, 7): - import unittest2 as unittest - else: - import unittest -else: +if sys.version_info[0] >= 3: import unittest from io import StringIO +else: + from cStringIO import StringIO + import unittest # UTILS From e4bf03c382696699a4dddd4030a30261fdcd7a91 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 10:18:28 -0500 Subject: [PATCH 42/46] expand caught errors for Could not read SSH configuration file --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index adb771c..97719b8 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -829,7 +829,7 @@ def _read_ssh_config(ssh_host, if compression is None: compression = hostname_info.get('compression', '') compression = compression.upper() == 'YES' - except OSError: + except (IOError, AssertionError, OSError): if logger: logger.warning( 'Could not read SSH configuration file: %s', From 4f74247dd511519f658bdca5dd959c0f86083788 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 12:26:00 -0500 Subject: [PATCH 43/46] Resolve SyntaxWarning: 'return' in a 'finally' block --- sshtunnel.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 97719b8..1edf9e5 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -838,13 +838,13 @@ def _read_ssh_config(ssh_host, except (AttributeError, TypeError): # ssh_config_file is None if logger: logger.info('Skipping loading of ssh configuration file') - finally: - return (ssh_host, - ssh_username or getpass.getuser(), - ssh_pkey, - int(ssh_port) if ssh_port else 22, # fallback value - ssh_proxy, - compression) + + return (ssh_host, + ssh_username or getpass.getuser(), + ssh_pkey, + int(ssh_port) if ssh_port else 22, # fallback value + ssh_proxy, + compression) @staticmethod def _consolidate_binds(local_binds, remote_binds): From 286109807834e65a465441c9b3ec10cfd0048445 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:31:58 -0500 Subject: [PATCH 44/46] reformat for 79 line length --- docs/conf.py | 136 ++--- e2e_tests/run_docker_e2e_db_tests.py | 48 +- e2e_tests/run_docker_e2e_hangs_tests.py | 4 +- pyproject.toml | 2 +- setup.py | 22 +- sshtunnel.py | 740 +++++++++++++----------- tests/test_forwarder.py | 655 ++++++++++++--------- 7 files changed, 877 insertions(+), 730 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index b33942f..e8e11b7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,19 +22,20 @@ def _warn_node(self, msg, node): if not msg.startswith('nonlocal image URI found:'): line, col = get_source_line(node) - self._warnfunc(msg, "{0}:{1}".format(line, col)) + self._warnfunc(msg, '{0}:{1}'.format(line, col)) + sphinx.environment.BuildEnvironment.warn_node = _warn_node # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -54,7 +55,7 @@ def _warn_node(self, msg, node): source_suffix = '.rst' # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' @@ -82,9 +83,9 @@ def _warn_node(self, msg, node): # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -92,27 +93,27 @@ def _warn_node(self, msg, node): # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -127,91 +128,91 @@ def _warn_node(self, msg, node): # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = False # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'sshtunneldoc' @@ -219,59 +220,58 @@ def _warn_node(self, msg, node): # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', + # Latex figure (float) alignment + #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'sshtunnel.tex', 'sshtunnel Documentation', - 'Pahaz Blinov', 'manual'), + ( + master_doc, + 'sshtunnel.tex', + 'sshtunnel Documentation', + 'Pahaz Blinov', + 'manual', + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'sshtunnel', 'sshtunnel Documentation', - [author], 1) -] +man_pages = [(master_doc, 'sshtunnel', 'sshtunnel Documentation', [author], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -280,22 +280,28 @@ def _warn_node(self, msg, node): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'sshtunnel', 'sshtunnel Documentation', - author, 'sshtunnel', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + 'sshtunnel', + 'sshtunnel Documentation', + author, + 'sshtunnel', + 'One line description of project.', + 'Miscellaneous', + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False intersphinx_mapping = { 'paramiko': ('http://docs.paramiko.org/en/latest', None), diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index 5ef688a..0b37a36 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -14,12 +14,16 @@ sshtunnel.DEFAULT_LOGLEVEL = 1 logging.basicConfig( - format='%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s', level=1) + format='%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s', + level=1, +) logger = logging.root SSH_SERVER_ADDRESS = ('127.0.0.1', 2223) SSH_SERVER_USERNAME = 'linuxserver' -SSH_PKEY = os.path.join(os.path.dirname(__file__), 'ssh-server-config', 'ssh_host_rsa_key') +SSH_PKEY = os.path.join( + os.path.dirname(__file__), 'ssh-server-config', 'ssh_host_rsa_key' +) SSH_SERVER_REMOTE_SIDE_ADDRESS_PG = ('10.5.0.5', 5432) SSH_SERVER_REMOTE_SIDE_ADDRESS_MYSQL = ('10.5.0.6', 3306) SSH_SERVER_REMOTE_SIDE_ADDRESS_MONGO = ('10.5.0.7', 27017) @@ -29,7 +33,8 @@ PG_PASSWORD = 'postgres' PG_QUERY = 'select version()' PG_EXPECT = literal_eval( - """('PostgreSQL 13.0 (Debian 13.0-1.pgdg100+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 8.3.0-6) 8.3.0, 64-bit',)""") + """('PostgreSQL 13.0 (Debian 13.0-1.pgdg100+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 8.3.0-6) 8.3.0, 64-bit',)""" +) MYSQL_DATABASE_NAME = 'main' MYSQL_USERNAME = 'mysql' @@ -42,7 +47,8 @@ MONGO_PASSWORD = 'mongo' MONGO_QUERY = lambda client, db: client.server_info() MONGO_EXPECT = literal_eval( - """{'version': '3.6.23', 'gitVersion': 'd352e6a4764659e0d0350ce77279de3c1f243e5c', 'modules': [], 'allocator': 'tcmalloc', 'javascriptEngine': 'mozjs', 'sysInfo': 'deprecated', 'versionArray': [3, 6, 23, 0], 'openssl': {'running': 'OpenSSL 1.0.2g 1 Mar 2016', 'compiled': 'OpenSSL 1.0.2g 1 Mar 2016'}, 'buildEnvironment': {'distmod': 'ubuntu1604', 'distarch': 'x86_64', 'cc': '/opt/mongodbtoolchain/v2/bin/gcc: gcc (GCC) 5.4.0', 'ccflags': '-fno-omit-frame-pointer -fno-strict-aliasing -ggdb -pthread -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -Werror -O2 -Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations -Wno-unused-but-set-variable -Wno-missing-braces -fstack-protector-strong -fno-builtin-memcmp', 'cxx': '/opt/mongodbtoolchain/v2/bin/g++: g++ (GCC) 5.4.0', 'cxxflags': '-Woverloaded-virtual -Wno-maybe-uninitialized -std=c++14', 'linkflags': '-pthread -Wl,-z,now -rdynamic -Wl,--fatal-warnings -fstack-protector-strong -fuse-ld=gold -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack -Wl,-z,relro', 'target_arch': 'x86_64', 'target_os': 'linux'}, 'bits': 64, 'debug': False, 'maxBsonObjectSize': 16777216, 'storageEngines': ['devnull', 'ephemeralForTest', 'mmapv1', 'wiredTiger'], 'ok': 1.0}""") + """{'version': '3.6.23', 'gitVersion': 'd352e6a4764659e0d0350ce77279de3c1f243e5c', 'modules': [], 'allocator': 'tcmalloc', 'javascriptEngine': 'mozjs', 'sysInfo': 'deprecated', 'versionArray': [3, 6, 23, 0], 'openssl': {'running': 'OpenSSL 1.0.2g 1 Mar 2016', 'compiled': 'OpenSSL 1.0.2g 1 Mar 2016'}, 'buildEnvironment': {'distmod': 'ubuntu1604', 'distarch': 'x86_64', 'cc': '/opt/mongodbtoolchain/v2/bin/gcc: gcc (GCC) 5.4.0', 'ccflags': '-fno-omit-frame-pointer -fno-strict-aliasing -ggdb -pthread -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -Werror -O2 -Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations -Wno-unused-but-set-variable -Wno-missing-braces -fstack-protector-strong -fno-builtin-memcmp', 'cxx': '/opt/mongodbtoolchain/v2/bin/g++: g++ (GCC) 5.4.0', 'cxxflags': '-Woverloaded-virtual -Wno-maybe-uninitialized -std=c++14', 'linkflags': '-pthread -Wl,-z,now -rdynamic -Wl,--fatal-warnings -fstack-protector-strong -fuse-ld=gold -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack -Wl,-z,relro', 'target_arch': 'x86_64', 'target_os': 'linux'}, 'bits': 64, 'debug': False, 'maxBsonObjectSize': 16777216, 'storageEngines': ['devnull', 'ephemeralForTest', 'mmapv1', 'wiredTiger'], 'ok': 1.0}""" +) def run_postgres_query(port, query=PG_QUERY): @@ -64,7 +70,8 @@ def wait(conn): select.select([conn.fileno()], [], []) else: raise psycopg2.OperationalError( - "poll() returned %s from _wait function" % state) + 'poll() returned %s from _wait function' % state + ) def wait_timeout(conn): while 1: @@ -91,7 +98,7 @@ def wait_timeout(conn): return ASYNC_READ_TIMEOUT else: raise psycopg2.OperationalError( - "poll() returned %s from _wait_timeout function" % state + 'poll() returned %s from _wait_timeout function' % state ) return None @@ -103,7 +110,7 @@ def wait_timeout(conn): user=PG_USERNAME, password=PG_PASSWORD, sslmode='disable', - async_=1 + async_=1, ) wait(pg_conn) cur = pg_conn.cursor() @@ -116,6 +123,7 @@ def wait_timeout(conn): def run_mysql_query(port, query=MYSQL_QUERY): import pymysql + conn = pymysql.connect( host='127.0.0.1', port=port, @@ -123,7 +131,8 @@ def run_mysql_query(port, query=MYSQL_QUERY): password=MYSQL_PASSWORD, database=MYSQL_DATABASE_NAME, connect_timeout=5, - read_timeout=5) + read_timeout=5, + ) cursor = conn.cursor() cursor.execute(query) return cursor.fetchall() @@ -131,20 +140,25 @@ def run_mysql_query(port, query=MYSQL_QUERY): def run_mongo_query(port, query=MONGO_QUERY): import pymongo + client = pymongo.MongoClient('127.0.0.1', port) db = client[MONGO_DATABASE_NAME] return query(client, db) def create_tunnel(): - logging.info('Creating SSHTunnelForwarder... (sshtunnel v%s, paramiko v%s)', - sshtunnel.__version__, paramiko.__version__) + logging.info( + 'Creating SSHTunnelForwarder... (sshtunnel v%s, paramiko v%s)', + sshtunnel.__version__, + paramiko.__version__, + ) return SSHTunnelForwarder( SSH_SERVER_ADDRESS, ssh_username=SSH_SERVER_USERNAME, ssh_pkey=SSH_PKEY, remote_bind_addresses=[ - SSH_SERVER_REMOTE_SIDE_ADDRESS_PG, SSH_SERVER_REMOTE_SIDE_ADDRESS_MYSQL, + SSH_SERVER_REMOTE_SIDE_ADDRESS_PG, + SSH_SERVER_REMOTE_SIDE_ADDRESS_MYSQL, SSH_SERVER_REMOTE_SIDE_ADDRESS_MONGO, ], logger=logger, @@ -222,14 +236,16 @@ def show_threading_state_if_required(): if len(current_threads) > 1: logging.warning('[2] STACK INFO') - code = ["\n\n*** STACKTRACE - START ***\n"] + code = ['\n\n*** STACKTRACE - START ***\n'] for threadId, stack in sys._current_frames().items(): - code.append("\n# ThreadID: %s" % threadId) + code.append('\n# ThreadID: %s' % threadId) for filename, lineno, name, line in traceback.extract_stack(stack): - code.append('File: "%s", line %d, in %s' % (filename, lineno, name)) + code.append( + 'File: "%s", line %d, in %s' % (filename, lineno, name) + ) if line: - code.append(" %s" % (line.strip())) - code.append("\n*** STACKTRACE - END ***\n\n") + code.append(' %s' % (line.strip())) + code.append('\n*** STACKTRACE - END ***\n\n') logging.info('\n'.join(code)) diff --git a/e2e_tests/run_docker_e2e_hangs_tests.py b/e2e_tests/run_docker_e2e_hangs_tests.py index 7abd491..08b4648 100644 --- a/e2e_tests/run_docker_e2e_hangs_tests.py +++ b/e2e_tests/run_docker_e2e_hangs_tests.py @@ -4,7 +4,9 @@ import sshtunnel if __name__ == '__main__': - path = os.path.join(os.path.dirname(__file__), 'run_docker_e2e_db_tests.py') + path = os.path.join( + os.path.dirname(__file__), 'run_docker_e2e_db_tests.py' + ) with open(path) as f: exec(f.read()) logging.warning('RUN') diff --git a/pyproject.toml b/pyproject.toml index b0471b7..864b334 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta:__legacy__" \ No newline at end of file +build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index 3499d54..ef40ce6 100644 --- a/setup.py +++ b/setup.py @@ -29,33 +29,27 @@ with open(path.join(here, name + '.py'), encoding='utf-8') as f: data = f.read() version = literal_eval( - re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1) + re.search('__version__[ ]*=[ ]*([^\r\n]+)', data).group(1) ) setup( name=name, - # Versions should comply with PEP440. For a discussion on single-sourcing # the version across setup.py and the project code, see # https://packaging.python.org/en/latest/single_source_version.html version=version, - description=description, long_description='\n'.join((long_description, documentation, changelog)), long_description_content_type='text/x-rst', - # The project's main homepage. url=url, download_url=ppa + version + '.zip', - # Author details author='Pahaz White', author_email='pahaz.white@gmail.com', - # Choose your license license='MIT', - # See https://pypi.python.org/pypi?%3Aaction=list_classifiers classifiers=[ # How mature is this project? Common values are @@ -63,14 +57,11 @@ # 4 - Beta # 5 - Production/Stable 'Development Status :: 3 - Alpha', - # Indicate who your project is intended for 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', - # Pick your license as you wish (should match "license" above) 'License :: OSI Approved :: MIT License', - # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 'Programming Language :: Python :: 2', @@ -82,20 +73,15 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', ], - platforms=['unix', 'macos', 'windows'], - # What does your project relate to? keywords='ssh tunnel paramiko proxy tcp-forward', - # You can just specify the packages manually here if your project is # simple. Or you can use find_packages(). # packages=find_packages(exclude=['contrib', 'docs', 'tests']), - # Alternatively, if you want to distribute just a my_module.py, uncomment # this: - py_modules=["sshtunnel"], - + py_modules=['sshtunnel'], # List run-time dependencies here. These will be installed by pip when # your project is installed. For an analysis of "install_requires" vs pip's # requirements files see: @@ -103,7 +89,6 @@ install_requires=[ 'paramiko>=2.7.2', ], - # List additional groups of dependencies here (e.g. development # dependencies). You can install these using the following syntax, # for example: @@ -121,14 +106,12 @@ 'sphinxcontrib-napoleon', ], }, - # If there are data files included in your packages that need to be # installed, specify them here. If using Python 2.6 or less, then these # have to be included in MANIFEST.in as well. package_data={ 'tests': ['testrsa.key'], }, - # To provide executable scripts, use entry points in preference to the # "scripts" keyword. Entry points provide cross-platform support and allow # pip to create the appropriate form of executable for the target platform. @@ -137,5 +120,4 @@ 'sshtunnel=sshtunnel:_cli_main', ] }, - ) diff --git a/sshtunnel.py b/sshtunnel.py index 1edf9e5..0d69c4c 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -32,6 +32,7 @@ if sys.version_info[0] >= 3: import queue import socketserver + string_types = str input_ = input else: @@ -57,7 +58,7 @@ 'ssh_address': 'ssh_address_or_host', 'ssh_host': 'ssh_address_or_host', 'ssh_private_key': 'ssh_pkey', - 'raise_exception_if_any_forwarder_have_a_problem': 'mute_exceptions' + 'raise_exception_if_any_forwarder_have_a_problem': 'mute_exceptions', } # logging @@ -66,8 +67,11 @@ logging.addLevelName(TRACE_LEVEL, 'TRACE') DEFAULT_SSH_DIRECTORY = '~/.ssh' -_StreamServer = socketserver.UnixStreamServer if os.name == 'posix' \ +_StreamServer = ( + socketserver.UnixStreamServer + if os.name == 'posix' else socketserver.TCPServer +) #: Path of optional ssh configuration file DEFAULT_SSH_DIRECTORY = '~/.ssh' @@ -127,15 +131,15 @@ def check_address(address): os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK) ): - msg = ( - 'ADDRESS not a valid socket domain socket ({0})' - .format(address) + msg = 'ADDRESS not a valid socket domain socket ({0})'.format( + address ) raise ValueError(msg) else: msg = ( - 'ADDRESS is not a tuple, string, or character buffer ({0})' - .format(type(address).__name__) + 'ADDRESS is not a tuple, string, or character buffer ({0})'.format( + type(address).__name__ + ) ) raise TypeError(msg) @@ -167,7 +171,7 @@ def check_addresses(address_list, is_remote=False): """ assert all(isinstance(x, (tuple, string_types)) for x in address_list) - if (is_remote and any(isinstance(x, string_types) for x in address_list)): + if is_remote and any(isinstance(x, string_types) for x in address_list): msg = 'UNIX domain sockets not allowed for remote addresses' raise AssertionError(msg) @@ -187,9 +191,9 @@ def _add_handler(logger, handler=None, loglevel=None): ) handler.setFormatter(logging.Formatter(_fmt)) else: - handler.setFormatter(logging.Formatter( - '%(asctime)s| %(levelname)-8s| %(message)s' - )) + handler.setFormatter( + logging.Formatter('%(asctime)s| %(levelname)-8s| %(message)s') + ) logger.addHandler(handler) @@ -204,16 +208,20 @@ def _check_paramiko_handlers(logger=None): else: console_handler = logging.StreamHandler() console_handler.setFormatter( - logging.Formatter('%(asctime)s | %(levelname)-8s| PARAMIKO: ' - '%(lineno)03d@%(module)-10s| %(message)s') + logging.Formatter( + '%(asctime)s | %(levelname)-8s| PARAMIKO: ' + '%(lineno)03d@%(module)-10s| %(message)s' + ) ) paramiko_logger.addHandler(console_handler) -def create_logger(logger=None, - loglevel=None, - capture_warnings=True, - add_paramiko_handler=True): +def create_logger( + logger=None, + loglevel=None, + capture_warnings=True, + add_paramiko_handler=True, +): """ Attach or create a new logger and add a console handler if not present @@ -246,15 +254,15 @@ def create_logger(logger=None, :py:class:`logging.Logger` """ - logger = logger or logging.getLogger( - 'sshtunnel.SSHTunnelForwarder' - ) + logger = logger or logging.getLogger('sshtunnel.SSHTunnelForwarder') if not any(isinstance(x, logging.Handler) for x in logger.handlers): logger.setLevel(loglevel or DEFAULT_LOGLEVEL) console_handler = logging.StreamHandler() - _add_handler(logger, - handler=console_handler, - loglevel=loglevel or DEFAULT_LOGLEVEL) + _add_handler( + logger, + handler=console_handler, + loglevel=loglevel or DEFAULT_LOGLEVEL, + ) if loglevel: # override if loglevel was set logger.setLevel(loglevel) for handler in logger.handlers: @@ -277,15 +285,17 @@ def address_to_str(address): def _remove_none_values(dictionary): - """ Remove dictionary keys whose value is None """ - return list(map(dictionary.pop, - [i for i in dictionary if dictionary[i] is None])) + """Remove dictionary keys whose value is None""" + return list( + map(dictionary.pop, [i for i in dictionary if dictionary[i] is None]) + ) def generate_random_string(length): letters = string.ascii_letters + string.digits return ''.join(random.choice(letters) for _ in range(length)) + ######################## # # # Errors # @@ -294,7 +304,7 @@ def generate_random_string(length): class BaseSSHTunnelForwarderError(Exception): - """ Exception raised by :py:class:`SSHTunnelForwarder` errors """ + """Exception raised by :py:class:`SSHTunnelForwarder` errors""" def __init__(self, *args, **kwargs): self.value = kwargs.pop('value', args[0] if args else '') @@ -304,7 +314,7 @@ def __str__(self): class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): - """ Exception for Tunnel forwarder errors """ + """Exception for Tunnel forwarder errors""" ######################## @@ -315,7 +325,8 @@ class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): class _ForwardHandler(socketserver.BaseRequestHandler): - """ Base handler for tunnel connections """ + """Base handler for tunnel connections""" + remote_address = None ssh_transport = None logger = None @@ -329,7 +340,8 @@ def _redirect(self, chan): if not data: self.logger.log( TRACE_LEVEL, - '>>> OUT %s recv empty data >>>', self.info + '>>> OUT %s recv empty data >>>', + self.info, ) break if self.logger.isEnabledFor(TRACE_LEVEL): @@ -338,14 +350,15 @@ def _redirect(self, chan): '>>> OUT %s send to %s: %s >>>', self.info, self.remote_address, - hexlify(data) + hexlify(data), ) chan.sendall(data) if chan in rqst: # else if not chan.recv_ready(): self.logger.log( TRACE_LEVEL, - '<<< IN %s recv is not ready <<<', self.info + '<<< IN %s recv is not ready <<<', + self.info, ) break data = chan.recv(16384) @@ -353,15 +366,16 @@ def _redirect(self, chan): hex_data = hexlify(data) self.logger.log( TRACE_LEVEL, - '<<< IN %s recv: %s <<<', self.info, hex_data + '<<< IN %s recv: %s <<<', + self.info, + hex_data, ) self.request.sendall(data) def handle(self): uid = generate_random_string(5) self.info = '#{0} <-- {1}'.format( - uid, self.client_address - or self.server.local_address + uid, self.client_address or self.server.local_address ) src_address = self.request.getpeername() if not isinstance(src_address, tuple): @@ -371,7 +385,7 @@ def handle(self): kind='direct-tcpip', dest_addr=self.remote_address, src_addr=src_address, - timeout=TUNNEL_TIMEOUT + timeout=TUNNEL_TIMEOUT, ) except (OSError, paramiko.SSHException) as e: type_msg = 'ssh ' if isinstance(e, paramiko.SSHException) else '' @@ -379,11 +393,7 @@ def handle(self): self.logger.log(TRACE_LEVEL, '%s %s', self.info, exc_msg) raise HandlerSSHTunnelForwarderError(exc_msg) - self.logger.log( - TRACE_LEVEL, - '%s connected', - self.info - ) + self.logger.log(TRACE_LEVEL, '%s connected', self.info) try: self._redirect(chan) except OSError: @@ -393,26 +403,18 @@ def handle(self): # the exception beyond this point... self.logger.log(TRACE_LEVEL, '%s sending RST', self.info) except Exception as e: - self.logger.log( - TRACE_LEVEL, - '%s error: %s', - self.info, - repr(e) - ) + self.logger.log(TRACE_LEVEL, '%s error: %s', self.info, repr(e)) finally: chan.close() self.request.close() - self.logger.log( - TRACE_LEVEL, - '%s connection closed.', - self.info - ) + self.logger.log(TRACE_LEVEL, '%s connection closed.', self.info) class _ForwardServer(socketserver.TCPServer): # Not Threading """ Non-threading version of the forward server """ + allow_reuse_address = True # faster rebinding def __init__(self, *args, **kwargs): @@ -432,7 +434,7 @@ def handle_error(self, request, client_address): 'to remote', remote_side, 'side of the tunnel', - exc + exc, ) try: self.tunnel_ok.put(item=False, block=False, timeout=0.1) @@ -471,6 +473,7 @@ class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): """ Allow concurrent connections to each tunnel """ + # If True, cleanly stop threads created by ThreadingMixIn when quitting # This value is overrides by SSHTunnelForwarder.daemon_forward_servers daemon_threads = _DAEMON @@ -512,11 +515,13 @@ def remote_port(self): return self.RequestHandlerClass.remote_address[1] -class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn, - _StreamForwardServer): +class _ThreadingStreamForwardServer( + socketserver.ThreadingMixIn, _StreamForwardServer +): """ Allow concurrent connections to each tunnel """ + # If True, cleanly stop threads created by ThreadingMixIn when quitting # This value is overrides by SSHTunnelForwarder.daemon_forward_servers daemon_threads = _DAEMON @@ -770,7 +775,7 @@ class SSHTunnelForwarder: { ('127.0.0.1', 55550): True, # this tunnel is up - ('127.0.0.1', 55551): False # this one isn't + ('127.0.0.1', 55551): False, # this one isn't } skip_tunnel_checkup (bool): @@ -779,6 +784,7 @@ class SSHTunnelForwarder: .. versionadded:: 0.1.0 """ + skip_tunnel_checkup = True # This option affects the `ForwardServer` and all his threads daemon_forward_servers = _DAEMON #: flag tunnel threads in daemon mode @@ -786,14 +792,16 @@ class SSHTunnelForwarder: daemon_transport = _DAEMON #: flag SSH transport thread in daemon mode @staticmethod - def _read_ssh_config(ssh_host, - ssh_config_file, - ssh_username=None, - ssh_pkey=None, - ssh_port=None, - ssh_proxy=None, - compression=None, - logger=None): + def _read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username=None, + ssh_pkey=None, + ssh_port=None, + ssh_proxy=None, + compression=None, + logger=None, + ): """ Read ssh_config_file and tries to look for user (ssh_username), identityfile (ssh_pkey), port (ssh_port) and proxycommand @@ -812,20 +820,15 @@ def _read_ssh_config(ssh_host, hostname_info = ssh_config.lookup(ssh_host) # gather settings for user, port and identity file # last resort: use the 'login name' of the user - ssh_username = ( - ssh_username - or hostname_info.get('user') - ) - ssh_pkey = ( - ssh_pkey - or hostname_info.get('identityfile', [None])[0] - ) + ssh_username = ssh_username or hostname_info.get('user') + ssh_pkey = ssh_pkey or hostname_info.get('identityfile', [None])[0] ssh_host = hostname_info.get('hostname') ssh_port = ssh_port or hostname_info.get('port') proxycommand = hostname_info.get('proxycommand') - ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if - proxycommand else None) + ssh_proxy = ssh_proxy or ( + paramiko.ProxyCommand(proxycommand) if proxycommand else None + ) if compression is None: compression = hostname_info.get('compression', '') compression = compression.upper() == 'YES' @@ -833,18 +836,20 @@ def _read_ssh_config(ssh_host, if logger: logger.warning( 'Could not read SSH configuration file: %s', - ssh_config_file + ssh_config_file, ) except (AttributeError, TypeError): # ssh_config_file is None if logger: logger.info('Skipping loading of ssh configuration file') - return (ssh_host, - ssh_username or getpass.getuser(), - ssh_pkey, - int(ssh_port) if ssh_port else 22, # fallback value - ssh_proxy, - compression) + return ( + ssh_host, + ssh_username or getpass.getuser(), + ssh_pkey, + int(ssh_port) if ssh_port else 22, # fallback value + ssh_proxy, + compression, + ) @staticmethod def _consolidate_binds(local_binds, remote_binds): @@ -868,12 +873,14 @@ def _consolidate_binds(local_binds, remote_binds): return local_binds @staticmethod - def _consolidate_auth(ssh_password=None, - ssh_pkey=None, - ssh_pkey_password=None, - allow_agent=True, - host_pkey_directories=None, - logger=None): + def _consolidate_auth( + ssh_password=None, + ssh_pkey=None, + ssh_pkey_password=None, + allow_agent=True, + host_pkey_directories=None, + logger=None, + ): """ Get sure authentication information is in place. @@ -891,7 +898,7 @@ def _consolidate_auth(ssh_password=None, ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( logger=logger, host_pkey_directories=host_pkey_directories, - allow_agent=allow_agent + allow_agent=allow_agent, ) if isinstance(ssh_pkey, string_types): @@ -900,13 +907,10 @@ def _consolidate_auth(ssh_password=None, ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, pkey_password=ssh_pkey_password or ssh_password, - logger=logger + logger=logger, ) elif logger: - logger.warning( - 'Private key file not found: %s', - ssh_pkey - ) + logger.warning('Private key file not found: %s', ssh_pkey) if isinstance(ssh_pkey, paramiko.pkey.PKey): ssh_loaded_pkeys.insert(0, ssh_pkey) @@ -922,9 +926,9 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): if not bind_address and not bind_addresses: if is_remote: msg = ( - "No {0} bind addresses specified. Use " + 'No {0} bind addresses specified. Use ' "'{0}_bind_address' or '{0}_bind_addresses'" - " argument".format(addr_kind) + ' argument'.format(addr_kind) ) raise ValueError(msg) return [] @@ -932,14 +936,14 @@ def _get_binds(bind_address, bind_addresses, is_remote=False): msg = ( "You can't use both '{0}_bind_address' and " "'{0}_bind_addresses' arguments. Use one of " - "them.".format(addr_kind) + 'them.'.format(addr_kind) ) raise ValueError(msg) if bind_address: bind_addresses = [bind_address] if not is_remote: # Add random port if missing in local bind - for (i, local_bind) in enumerate(bind_addresses): + for i, local_bind in enumerate(bind_addresses): if isinstance(local_bind, tuple) and len(local_bind) == 1: bind_addresses[i] = (local_bind[0], 0) check_addresses(bind_addresses, is_remote) @@ -957,23 +961,22 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): """ if deprecated_attrib not in _DEPRECATIONS: - msg = ( - '{0} not included in deprecations list' - .format(deprecated_attrib) + msg = '{0} not included in deprecations list'.format( + deprecated_attrib ) raise ValueError(msg) if deprecated_attrib in kwargs: - warnings.warn("'{0}' is DEPRECATED use '{1}' instead" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib]), - DeprecationWarning) + warnings.warn( + "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_attrib, _DEPRECATIONS[deprecated_attrib] + ), + DeprecationWarning, + ) if attrib: msg = ( "You can't use both '{0}' and '{1}'. " - "Please only use one of them" - .format( - deprecated_attrib, - _DEPRECATIONS[deprecated_attrib] + 'Please only use one of them'.format( + deprecated_attrib, _DEPRECATIONS[deprecated_attrib] ) ) raise ValueError(msg) @@ -981,29 +984,29 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): return attrib def __init__( - self, - ssh_address_or_host=None, - ssh_config_file=SSH_CONFIG_FILE, - ssh_host_key=None, - ssh_password=None, - ssh_pkey=None, - ssh_private_key_password=None, - ssh_proxy=None, - ssh_proxy_enabled=True, - ssh_username=None, - local_bind_address=None, - local_bind_addresses=None, - logger=None, - mute_exceptions=False, - remote_bind_address=None, - remote_bind_addresses=None, - set_keepalive=5.0, - threaded=True, # old version False - compression=None, - allow_agent=True, # look for keys from an SSH agent - host_pkey_directories=None, # look for keys in ~/.ssh - *args, - **kwargs # for backwards compatibility + self, + ssh_address_or_host=None, + ssh_config_file=SSH_CONFIG_FILE, + ssh_host_key=None, + ssh_password=None, + ssh_pkey=None, + ssh_private_key_password=None, + ssh_proxy=None, + ssh_proxy_enabled=True, + ssh_username=None, + local_bind_address=None, + local_bind_addresses=None, + logger=None, + mute_exceptions=False, + remote_bind_address=None, + remote_bind_addresses=None, + set_keepalive=5.0, + threaded=True, # old version False + compression=None, + allow_agent=True, # look for keys from an SSH agent + host_pkey_directories=None, # look for keys in ~/.ssh + *args, + **kwargs, # for backwards compatibility ): self.logger = logger or create_logger() @@ -1015,18 +1018,20 @@ def __init__( self.is_alive = False # Check if deprecated arguments ssh_address or ssh_host were used for deprecated_argument in ['ssh_address', 'ssh_host']: - ssh_address_or_host = self._process_deprecated(ssh_address_or_host, - deprecated_argument, - kwargs) + ssh_address_or_host = self._process_deprecated( + ssh_address_or_host, deprecated_argument, kwargs + ) # other deprecated arguments - ssh_pkey = self._process_deprecated(ssh_pkey, - 'ssh_private_key', - kwargs) + ssh_pkey = self._process_deprecated( + ssh_pkey, 'ssh_private_key', kwargs + ) - self._raise_fwd_exc = self._process_deprecated( - None, - 'raise_exception_if_any_forwarder_have_a_problem', - kwargs) or not mute_exceptions + self._raise_fwd_exc = ( + self._process_deprecated( + None, 'raise_exception_if_any_forwarder_have_a_problem', kwargs + ) + or not mute_exceptions + ) if isinstance(ssh_address_or_host, tuple): check_address(ssh_address_or_host) @@ -1040,29 +1045,33 @@ def __init__( raise ValueError(msg) # remote binds - self._remote_binds = self._get_binds(remote_bind_address, - remote_bind_addresses, - is_remote=True) + self._remote_binds = self._get_binds( + remote_bind_address, remote_bind_addresses, is_remote=True + ) # local binds - self._local_binds = self._get_binds(local_bind_address, - local_bind_addresses) - self._local_binds = self._consolidate_binds(self._local_binds, - self._remote_binds) - - (self.ssh_host, - self.ssh_username, - ssh_pkey, # still needs to go through _consolidate_auth - self.ssh_port, - self.ssh_proxy, - self.compression) = self._read_ssh_config( - ssh_host, - ssh_config_file, - ssh_username, - ssh_pkey, - ssh_port, - ssh_proxy if ssh_proxy_enabled else None, - compression, - self.logger + self._local_binds = self._get_binds( + local_bind_address, local_bind_addresses + ) + self._local_binds = self._consolidate_binds( + self._local_binds, self._remote_binds + ) + + ( + self.ssh_host, + self.ssh_username, + ssh_pkey, # still needs to go through _consolidate_auth + self.ssh_port, + self.ssh_proxy, + self.compression, + ) = self._read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username, + ssh_pkey, + ssh_port, + ssh_proxy if ssh_proxy_enabled else None, + compression, + self.logger, ) (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth( @@ -1071,7 +1080,7 @@ def __init__( ssh_pkey_password=ssh_private_key_password, allow_agent=allow_agent, host_pkey_directories=host_pkey_directories, - logger=self.logger + logger=self.logger, ) check_host(self.ssh_host) @@ -1081,20 +1090,18 @@ def __init__( "Connecting to gateway: %s:%s as user '%s'", self.ssh_host, self.ssh_port, - self.ssh_username + self.ssh_username, ) - self.logger.debug( - 'Concurrent connections allowed: %s', - self._threaded - ) + self.logger.debug('Concurrent connections allowed: %s', self._threaded) def __del__(self): if self.is_active or self.is_alive: self.logger.warning( "It looks like you didn't call the .stop() before " - "the SSHTunnelForwarder obj was collected by " - "the garbage collector! Running .stop(force=True)") + 'the SSHTunnelForwarder obj was collected by ' + 'the garbage collector! Running .stop(force=True)' + ) self.stop(force=True) def local_is_up(self, target): @@ -1117,17 +1124,19 @@ def local_is_up(self, target): try: check_address(target) except ValueError: - self.logger.warning('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.') + self.logger.warning( + 'Target must be a tuple (IP, port), where IP ' + 'is a string (i.e. "192.168.0.1") and port is ' + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.' + ) return False self.check_tunnels() return self.tunnel_is_up.get(target, True) def _check_tunnel(self, _srv): - """ Check if tunnel is already established """ + """Check if tunnel is already established""" if self.skip_tunnel_checkup: self.tunnel_is_up[_srv.local_address] = True return @@ -1139,25 +1148,22 @@ def _check_tunnel(self, _srv): s.settimeout(TUNNEL_TIMEOUT) try: # Windows raises WinError 10049 if trying to connect to 0.0.0.0 - connect_to = ('127.0.0.1', _srv.local_port) \ - if _srv.local_host == '0.0.0.0' else _srv.local_address + connect_to = ( + ('127.0.0.1', _srv.local_port) + if _srv.local_host == '0.0.0.0' + else _srv.local_address + ) s.connect(connect_to) self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( timeout=TUNNEL_TIMEOUT * 1.1 ) - self.logger.debug( - 'Tunnel to %s is DOWN', _srv.remote_address - ) + self.logger.debug('Tunnel to %s is DOWN', _srv.remote_address) except OSError: - self.logger.debug( - 'Tunnel to %s is DOWN', _srv.remote_address - ) + self.logger.debug('Tunnel to %s is DOWN', _srv.remote_address) self.tunnel_is_up[_srv.local_address] = False except queue.Empty: - self.logger.debug( - 'Tunnel to %s is UP', _srv.remote_address - ) + self.logger.debug('Tunnel to %s is UP', _srv.remote_address) self.tunnel_is_up[_srv.local_address] = True finally: s.close() @@ -1180,18 +1186,23 @@ def _make_ssh_forward_handler_class(self, remote_address_): """ Make SSH Handler class """ + class Handler(_ForwardHandler): remote_address = remote_address_ ssh_transport = self._transport logger = self.logger + return Handler def _make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer def _make_stream_ssh_forward_server_class(self, remote_address_): - return _ThreadingStreamForwardServer if self._threaded \ + return ( + _ThreadingStreamForwardServer + if self._threaded else _StreamForwardServer + ) def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): if self._raise_fwd_exc: @@ -1204,9 +1215,11 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): """ _Handler = self._make_ssh_forward_handler_class(remote_address) try: - forward_maker_class = self._make_stream_ssh_forward_server_class \ - if isinstance(local_bind_address, string_types) \ + forward_maker_class = ( + self._make_stream_ssh_forward_server_class + if isinstance(local_bind_address, string_types) else self._make_ssh_forward_server_class + ) _Server = forward_maker_class(remote_address) ssh_forward_server = _Server( local_bind_address, @@ -1223,22 +1236,24 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): BaseSSHTunnelForwarderError, 'Problem setting up ssh {0} <> {1} forwarder. You can ' 'suppress this exception by using the `mute_exceptions`' - 'argument'.format(address_to_str(local_bind_address), - address_to_str(remote_address)) + 'argument'.format( + address_to_str(local_bind_address), + address_to_str(remote_address), + ), ) except OSError: self._raise( BaseSSHTunnelForwarderError, "Couldn't open tunnel {0} <> {1} might be in use or " - "destination not reachable".format( + 'destination not reachable'.format( address_to_str(local_bind_address), - address_to_str(remote_address) - ) + address_to_str(remote_address), + ), ) @staticmethod def get_agent_keys(logger=None): - """ Load public keys from any available SSH agent + """Load public keys from any available SSH agent Arguments: logger (Optional[logging.Logger]) @@ -1278,15 +1293,20 @@ def get_keys( # noqa: C901 too complex list """ - keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \ - if allow_agent else [] + keys = ( + SSHTunnelForwarder.get_agent_keys(logger=logger) + if allow_agent + else [] + ) if host_pkey_directories is None: host_pkey_directories = [DEFAULT_SSH_DIRECTORY] - paramiko_key_types = {'rsa': paramiko.RSAKey, - 'dsa': paramiko.DSSKey, - 'ecdsa': paramiko.ECDSAKey} + paramiko_key_types = { + 'rsa': paramiko.RSAKey, + 'dsa': paramiko.DSSKey, + 'ecdsa': paramiko.ECDSAKey, + } if hasattr(paramiko, 'Ed25519Key'): paramiko_key_types['ed25519'] = paramiko.Ed25519Key for directory in host_pkey_directories: @@ -1299,7 +1319,7 @@ def get_keys( # noqa: C901 too complex ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, logger=logger, - key_type=value + key_type=value, ) if ssh_pkey: keys.append(ssh_pkey) @@ -1308,14 +1328,14 @@ def get_keys( # noqa: C901 too complex logger.warning( 'Private key file %s check error: %s', ssh_pkey_expanded, - exc + exc, ) if logger: logger.info('%s key(s) loaded', len(keys)) return keys def _get_transport(self): - """ Return the SSH transport to the remote gateway """ + """Return the SSH transport to the remote gateway""" if self.ssh_proxy: if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): proxy_repr = repr(self.ssh_proxy.cmd[1]) @@ -1343,7 +1363,7 @@ def _get_transport(self): self.logger.debug( 'Transport socket info: %s, timeout=%s', sock_info, - sock_timeout + sock_timeout, ) return transport @@ -1356,11 +1376,13 @@ def _check_is_started(self): raise HandlerSSHTunnelForwarderError(msg) def _stop_transport(self, force=False): - """ Close the underlying transport when nothing more is needed """ + """Close the underlying transport when nothing more is needed""" try: self._check_is_started() - except (BaseSSHTunnelForwarderError, - HandlerSSHTunnelForwarderError) as e: + except ( + BaseSSHTunnelForwarderError, + HandlerSSHTunnelForwarderError, + ) as e: self.logger.warning(e) if force and self.is_active: # don't wait connections @@ -1373,7 +1395,7 @@ def _stop_transport(self, force=False): 'Shutting down tunnel: %s <> %s (%s)', address_to_str(_srv.local_address), address_to_str(_srv.remote_address), - status + status, ) _srv.shutdown() _srv.server_close() @@ -1385,7 +1407,7 @@ def _stop_transport(self, force=False): self.logger.error( 'Unable to unlink socket %s: %s', _srv.local_address, - repr(e) + repr(e), ) self.is_alive = False if self.is_active: @@ -1405,14 +1427,15 @@ def _connect_to_gateway(self): """ for key in self.ssh_pkeys: self.logger.debug( - 'Trying to log in with key: %s', - hexlify(key.get_fingerprint()) + 'Trying to log in with key: %s', hexlify(key.get_fingerprint()) ) try: self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - pkey=key) + self._transport.connect( + hostkey=self.ssh_host_key, + username=self.ssh_username, + pkey=key, + ) if self._transport.is_alive: return except paramiko.AuthenticationException: @@ -1422,13 +1445,15 @@ def _connect_to_gateway(self): if self.ssh_password: # avoid conflict using both pass and pkey self.logger.debug( 'Trying to log in with password: %s', - '*' * len(self.ssh_password) + '*' * len(self.ssh_password), ) try: self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - password=self.ssh_password) + self._transport.connect( + hostkey=self.ssh_host_key, + username=self.ssh_username, + password=self.ssh_password, + ) if self._transport.is_alive: return except paramiko.AuthenticationException: @@ -1447,7 +1472,7 @@ def _create_tunnels(self): except socket.gaierror: # raised by paramiko.Transport self.logger.error( 'Could not resolve IP address for %s, aborting!', - self.ssh_host + self.ssh_host, ) return except (OSError, paramiko.SSHException) as e: @@ -1455,23 +1480,21 @@ def _create_tunnels(self): 'Could not connect to gateway %s:%s : %s', self.ssh_host, self.ssh_port, - e.args[0] + e.args[0], ) return - for (rem, loc) in zip(self._remote_binds, self._local_binds): + for rem, loc in zip(self._remote_binds, self._local_binds): try: self._make_ssh_forward_server(rem, loc) except BaseSSHTunnelForwarderError as e: self.logger.error( - 'Problem setting SSH Forwarder up: %s', - e.value + 'Problem setting SSH Forwarder up: %s', e.value ) @staticmethod - def read_private_key_file(pkey_file, - pkey_password=None, - key_type=None, - logger=None): + def read_private_key_file( + pkey_file, pkey_password=None, key_type=None, logger=None + ): """ Get SSH Public key from a private key file, given an optional password @@ -1491,18 +1514,17 @@ def read_private_key_file(pkey_file, ssh_pkey = None key_types = (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey) if hasattr(paramiko, 'Ed25519Key'): - key_types += (paramiko.Ed25519Key, ) + key_types += (paramiko.Ed25519Key,) for pkey_class in (key_type,) if key_type else key_types: try: ssh_pkey = pkey_class.from_private_key_file( - pkey_file, - password=pkey_password + pkey_file, password=pkey_password ) if logger: logger.debug( 'Private key file (%s, %s) successfully loaded', pkey_file, - pkey_class + pkey_class, ) break except paramiko.PasswordRequiredException: @@ -1517,7 +1539,7 @@ def read_private_key_file(pkey_file, pkey_file, 'could not be loaded as type', pkey_class, - 'or bad password' + 'or bad password', ) return ssh_pkey @@ -1528,38 +1550,42 @@ def _serve_forever_wrapper(self, _srv, poll_interval=0.1): self.logger.info( 'Opening tunnel: %s <> %s', address_to_str(_srv.local_address), - address_to_str(_srv.remote_address) + address_to_str(_srv.remote_address), ) _srv.serve_forever(poll_interval) # blocks until finished self.logger.info( 'Tunnel: %s <> %s released', address_to_str(_srv.local_address), - address_to_str(_srv.remote_address) + address_to_str(_srv.remote_address), ) def start(self): - """ Start the SSH tunnels """ + """Start the SSH tunnels""" if self.is_alive: self.logger.warning('Already started!') return self._create_tunnels() if not self.is_active: - self._raise(BaseSSHTunnelForwarderError, - reason='Could not establish session to SSH gateway') + self._raise( + BaseSSHTunnelForwarderError, + reason='Could not establish session to SSH gateway', + ) for _srv in self._server_list: thread = threading.Thread( target=self._serve_forever_wrapper, - args=(_srv, ), - name='Srv-{0}'.format(address_to_str(_srv.local_port)) + args=(_srv,), + name='Srv-{0}'.format(address_to_str(_srv.local_port)), ) thread.daemon = self.daemon_forward_servers thread.start() self._check_tunnel(_srv) self.is_alive = any(self.tunnel_is_up.values()) if not self.is_alive: - self._raise(HandlerSSHTunnelForwarderError, - 'An error occurred while opening tunnels.') + self._raise( + HandlerSSHTunnelForwarderError, + 'An error occurred while opening tunnels.', + ) def stop(self, force=False): """ @@ -1591,20 +1617,23 @@ def stop(self, force=False): """ self.logger.info('Closing all open connections...') - opened_address_text = ', '.join( - address_to_str(k.local_address) for k in self._server_list - ) or 'None' + opened_address_text = ( + ', '.join( + address_to_str(k.local_address) for k in self._server_list + ) + or 'None' + ) self.logger.debug('Listening tunnels: %s', opened_address_text) self._stop_transport(force=force) self._server_list = [] # reset server list self.tunnel_is_up = {} # reset tunnel status def close(self): - """ Stop the an active tunnel, alias to :meth:`.stop` """ + """Stop the an active tunnel, alias to :meth:`.stop`""" self.stop() def restart(self): - """ Restart connection to the gateway and tunnels """ + """Restart connection to the gateway and tunnels""" self.stop() self.start() @@ -1614,9 +1643,7 @@ def local_bind_port(self): self._check_is_started() if len(self._server_list) != 1: msg = 'Use .local_bind_ports property for more than one tunnel' - raise BaseSSHTunnelForwarderError( - msg - ) + raise BaseSSHTunnelForwarderError(msg) return self.local_bind_ports[0] @property @@ -1625,9 +1652,7 @@ def local_bind_host(self): self._check_is_started() if len(self._server_list) != 1: msg = 'Use .local_bind_hosts property for more than one tunnel' - raise BaseSSHTunnelForwarderError( - msg - ) + raise BaseSSHTunnelForwarderError(msg) return self.local_bind_hosts[0] @property @@ -1636,9 +1661,7 @@ def local_bind_address(self): self._check_is_started() if len(self._server_list) != 1: msg = 'Use .local_bind_addresses property for more than one tunnel' - raise BaseSSHTunnelForwarderError( - msg - ) + raise BaseSSHTunnelForwarderError(msg) return self.local_bind_addresses[0] @property @@ -1647,8 +1670,11 @@ def local_bind_ports(self): Return a list containing the ports of local side of the TCP tunnels """ self._check_is_started() - return [_server.local_port for _server in self._server_list if - _server.local_port is not None] + return [ + _server.local_port + for _server in self._server_list + if _server.local_port is not None + ] @property def local_bind_hosts(self): @@ -1656,8 +1682,11 @@ def local_bind_hosts(self): Return a list containing the IP addresses listening for the tunnels """ self._check_is_started() - return [_server.local_host for _server in self._server_list if - _server.local_host is not None] + return [ + _server.local_host + for _server in self._server_list + if _server.local_host is not None + ] @property def local_bind_addresses(self): @@ -1680,10 +1709,9 @@ def tunnel_bindings(self): @property def is_active(self): - """ Return True if the underlying SSH transport is up """ + """Return True if the underlying SSH transport is up""" return bool( - '_transport' in self.__dict__ - and self._transport.is_active() + '_transport' in self.__dict__ and self._transport.is_active() ) def __exit__(self, *args): @@ -1701,26 +1729,33 @@ def __enter__(self): def __str__(self): credentials = { 'password': self.ssh_password, - 'pkeys': [(key.get_name(), hexlify(key.get_fingerprint())) - for key in self.ssh_pkeys] - if any(self.ssh_pkeys) else None + 'pkeys': [ + (key.get_name(), hexlify(key.get_fingerprint())) + for key in self.ssh_pkeys + ] + if any(self.ssh_pkeys) + else None, } _remove_none_values(credentials) - template = os.linesep.join(['{0} object', - 'ssh gateway: {1}:{2}', - 'proxy: {3}', - 'username: {4}', - 'authentication: {5}', - 'hostkey: {6}', - 'status: {7}started', - 'keepalive messages: {8}', - 'tunnel connection check: {9}', - 'concurrent connections: {10}allowed', - 'compression: {11}requested', - 'logging level: {12}', - 'local binds: {13}', - 'remote binds: {14}']) - return (template.format( + template = os.linesep.join( + [ + '{0} object', + 'ssh gateway: {1}:{2}', + 'proxy: {3}', + 'username: {4}', + 'authentication: {5}', + 'hostkey: {6}', + 'status: {7}started', + 'keepalive messages: {8}', + 'tunnel connection check: {9}', + 'concurrent connections: {10}allowed', + 'compression: {11}requested', + 'logging level: {12}', + 'local binds: {13}', + 'remote binds: {14}', + ] + ) + return template.format( self.__class__, self.ssh_host, self.ssh_port, @@ -1729,15 +1764,16 @@ def __str__(self): credentials, self.ssh_host_key or 'not checked', '' if self.is_alive else 'not ', - 'disabled' if not self.set_keepalive else - 'every {0} sec'.format(self.set_keepalive), + 'disabled' + if not self.set_keepalive + else 'every {0} sec'.format(self.set_keepalive), 'disabled' if self.skip_tunnel_checkup else 'enabled', '' if self._threaded else 'not ', '' if self.compression else 'not ', logging.getLevelName(self.logger.level), self._local_binds, self._remote_binds, - )) + ) def __repr__(self): return self.__str__() @@ -1763,13 +1799,13 @@ def open_tunnel(*args, **kwargs): ssh_port=22, ssh_password=SSH_PASSWORD, remote_bind_address=(REMOTE_HOST, REMOTE_PORT), - local_bind_address=('', LOCAL_PORT) + local_bind_address=('', LOCAL_PORT), ) as server: def do_something(port): pass - print("LOCAL PORTS:", server.local_bind_port) + print('LOCAL PORTS:', server.local_bind_port) do_something(server.local_bind_port) Arguments: @@ -1800,25 +1836,25 @@ def do_something(port): # Check if deprecated arguments ssh_address or ssh_host were used for deprecated_argument in ['ssh_address', 'ssh_host']: ssh_address_or_host = SSHTunnelForwarder._process_deprecated( - ssh_address_or_host, - deprecated_argument, - kwargs + ssh_address_or_host, deprecated_argument, kwargs ) ssh_port = kwargs.pop('ssh_port', 22) skip_tunnel_checkup = kwargs.pop('skip_tunnel_checkup', True) block_on_close = kwargs.pop('block_on_close', None) if block_on_close: - warnings.warn("'block_on_close' is DEPRECATED. You should use either" - " .stop() or .stop(force=True), depends on what you do" - " with the active connections. This option has no" - " affect since 0.3.0", - DeprecationWarning) + warnings.warn( + "'block_on_close' is DEPRECATED. You should use either" + ' .stop() or .stop(force=True), depends on what you do' + ' with the active connections. This option has no' + ' affect since 0.3.0', + DeprecationWarning, + ) if not args: if isinstance(ssh_address_or_host, tuple): - args = (ssh_address_or_host, ) + args = (ssh_address_or_host,) else: - args = ((ssh_address_or_host, ssh_port), ) + args = ((ssh_address_or_host, ssh_port),) forwarder = SSHTunnelForwarder(*args, **kwargs) forwarder.skip_tunnel_checkup = skip_tunnel_checkup return forwarder @@ -1854,9 +1890,7 @@ def _bindlist(input_str): return _ip, int(_port) except ValueError: msg = 'Address tuple must be of type IP_ADDRESS:PORT' - raise argparse.ArgumentTypeError( - msg - ) + raise argparse.ArgumentTypeError(msg) except AssertionError: msg = "Both IP:PORT can't be missing!" raise argparse.ArgumentTypeError(msg) @@ -1867,43 +1901,48 @@ def _parse_arguments(args=None): Parse arguments directly passed from CLI """ parser = argparse.ArgumentParser( - description='Pure python ssh tunnel utils\n' - 'Version {0}'.format(__version__), - formatter_class=argparse.RawTextHelpFormatter + description='Pure python ssh tunnel utils\nVersion {0}'.format( + __version__ + ), + formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument( 'ssh_address', type=str, help='SSH server IP address (GW for SSH tunnels)\n' - 'set with "-- ssh_address" if immediately after ' - '-R or -L' + 'set with "-- ssh_address" if immediately after ' + '-R or -L', ) parser.add_argument( - '-U', '--username', + '-U', + '--username', type=str, dest='ssh_username', - help='SSH server account username' + help='SSH server account username', ) parser.add_argument( - '-p', '--server_port', + '-p', + '--server_port', type=int, dest='ssh_port', default=22, - help='SSH server TCP port (default: 22)' + help='SSH server TCP port (default: 22)', ) parser.add_argument( - '-P', '--password', + '-P', + '--password', type=str, dest='ssh_password', - help='SSH server account password' + help='SSH server account password', ) parser.add_argument( - '-R', '--remote_bind_address', + '-R', + '--remote_bind_address', type=_bindlist, nargs='+', default=[], @@ -1911,105 +1950,114 @@ def _parse_arguments(args=None): required=True, dest='remote_bind_addresses', help='Remote bind address sequence: ' - 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' - 'Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT\n' - 'If port is omitted, defaults to 22.\n' - 'Example: -R 10.10.10.10: 10.10.10.10:5900' + 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' + 'Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT\n' + 'If port is omitted, defaults to 22.\n' + 'Example: -R 10.10.10.10: 10.10.10.10:5900', ) parser.add_argument( - '-L', '--local_bind_address', + '-L', + '--local_bind_address', type=_bindlist, nargs='*', dest='local_bind_addresses', metavar='IP:PORT', help='Local bind address sequence: ' - 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' - 'Elements may also be valid UNIX socket domains: \n' - '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n' - 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, ' - 'being the local IP address optional.\n' - 'By default it will listen in all interfaces ' - '(0.0.0.0) and choose a random port.\n' - 'Example: -L :40000' + 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' + 'Elements may also be valid UNIX socket domains: \n' + '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n' + 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, ' + 'being the local IP address optional.\n' + 'By default it will listen in all interfaces ' + '(0.0.0.0) and choose a random port.\n' + 'Example: -L :40000', ) parser.add_argument( - '-k', '--ssh_host_key', - type=str, - help="Gateway's host key" + '-k', '--ssh_host_key', type=str, help="Gateway's host key" ) parser.add_argument( - '-K', '--private_key_file', + '-K', + '--private_key_file', dest='ssh_private_key', metavar='KEY_FILE', type=str, - help='RSA/DSS/ECDSA private key file' + help='RSA/DSS/ECDSA private key file', ) parser.add_argument( - '-S', '--private_key_password', + '-S', + '--private_key_password', dest='ssh_private_key_password', metavar='KEY_PASSWORD', type=str, - help='RSA/DSS/ECDSA private key password' + help='RSA/DSS/ECDSA private key password', ) parser.add_argument( - '-t', '--threaded', + '-t', + '--threaded', action='store_true', - help='Allow concurrent connections to each tunnel' + help='Allow concurrent connections to each tunnel', ) parser.add_argument( - '-v', '--verbose', + '-v', + '--verbose', action='count', default=0, help='Increase output verbosity (default: {0})'.format( logging.getLevelName(DEFAULT_LOGLEVEL) - ) + ), ) parser.add_argument( - '-V', '--version', + '-V', + '--version', action='version', version='%(prog)s {version}'.format(version=__version__), - help='Show version number and quit' + help='Show version number and quit', ) parser.add_argument( - '-x', '--proxy', + '-x', + '--proxy', type=_bindlist, dest='ssh_proxy', metavar='IP:PORT', - help='IP and port of SSH proxy to destination' + help='IP and port of SSH proxy to destination', ) parser.add_argument( - '-c', '--config', + '-c', + '--config', type=str, default=SSH_CONFIG_FILE, dest='ssh_config_file', - help='SSH configuration file, defaults to {0}'.format(SSH_CONFIG_FILE) + help='SSH configuration file, defaults to {0}'.format(SSH_CONFIG_FILE), ) parser.add_argument( - '-z', '--compress', + '-z', + '--compress', action='store_true', dest='compression', - help='Request server for compression over SSH transport' + help='Request server for compression over SSH transport', ) parser.add_argument( - '-n', '--noagent', + '-n', + '--noagent', action='store_false', dest='allow_agent', - help='Disable looking for keys from an SSH agent' + help='Disable looking for keys from an SSH agent', ) parser.add_argument( - '-d', '--host_pkey_directories', + '-d', + '--host_pkey_directories', nargs='*', dest='host_pkey_directories', metavar='FOLDER', diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 1f686b8..c8e2f95 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -32,6 +32,7 @@ # UTILS + def get_random_string(length=12): """ @@ -105,6 +106,7 @@ def capture_stdout_stderr(): # TESTS + class MockLoggingHandler(logging.Handler, object): """Mock logging handler to check for expected logs. @@ -113,8 +115,14 @@ class MockLoggingHandler(logging.Handler, object): """ def __init__(self, *args, **kwargs): - self.messages = {'debug': [], 'info': [], 'warning': [], 'error': [], - 'critical': [], 'trace': []} + self.messages = { + 'debug': [], + 'info': [], + 'warning': [], + 'error': [], + 'critical': [], + 'trace': [], + } super(MockLoggingHandler, self).__init__(*args, **kwargs) def emit(self, record): @@ -157,11 +165,11 @@ def get_allowed_auths(self, username): return allowed_auths def check_auth_password(self, username, password): - _ok = (username == SSH_USERNAME and password == SSH_PASSWORD) + _ok = username == SSH_USERNAME and password == SSH_PASSWORD self.log.debug( 'NullServer >> password for %s %sOK', username, - '' if _ok else 'NOT-' + '' if _ok else 'NOT-', ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED @@ -177,7 +185,7 @@ def check_auth_publickey(self, username, key): self.log.debug( 'NullServer >> pkey authentication for %s %sOK', username, - '' if _ok else 'NOT-' + '' if _ok else 'NOT-', ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED @@ -203,7 +211,7 @@ def check_channel_direct_tcpip_request(self, chanid, origin, destination): 'check_channel_direct_tcpip_request', chanid, origin, - destination + destination, ) return paramiko.OPEN_SUCCEEDED @@ -221,8 +229,7 @@ def setUpClass(cls): super(SSHClientTest, cls).setUpClass() socket.setdefaulttimeout(sshtunnel.SSH_TIMEOUT) cls.log = logging.getLogger(sshtunnel.__name__) - cls.log = sshtunnel.create_logger(logger=cls.log, - loglevel='DEBUG') + cls.log = sshtunnel.create_logger(logger=cls.log, loglevel='DEBUG') cls._sshtunnel_log_handler = MockLoggingHandler(level='DEBUG') cls.log.addHandler(cls._sshtunnel_log_handler) cls.sshtunnel_log_messages = cls._sshtunnel_log_handler.messages @@ -240,8 +247,8 @@ def setUp(self): self.log.info('setUp for: %s()', self._testMethodName.upper()) self.ssockl, self.saddr, self.sport = self.make_socket() self.esockl, self.eaddr, self.eport = self.make_socket() - self.log.info("Socket for ssh-server: %s:%s", self.saddr, self.sport) - self.log.info("Socket for echo-server: %s:%s", self.eaddr, self.eport) + self.log.info('Socket for ssh-server: %s:%s', self.saddr, self.sport) + self.log.info('Socket for echo-server: %s:%s', self.eaddr, self.eport) self.ssh_event = threading.Event() self.running_threads = [] @@ -258,14 +265,13 @@ def tearDown(self): self.log.info( 'thread %s (%s)', thread, - 'alive' if x.is_alive() else 'defunct' + 'alive' if x.is_alive() else 'defunct', ) while self.running_threads: for thread in self.running_threads: x = self.threads[thread] - self.wait_for_thread(self.threads[thread], - who='tearDown') + self.wait_for_thread(self.threads[thread], who='tearDown') if not x.is_alive(): self.log.info('thread %s now stopped', thread) @@ -286,16 +292,12 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport try: schan = self.ts.accept(timeout=timeout) - info = "forward-server schan <> echo" - self.log.info("%s accept()", info) - echo = socket.create_connection( - (self.eaddr, self.eport) - ) + info = 'forward-server schan <> echo' + self.log.info('%s accept()', info) + echo = socket.create_connection((self.eaddr, self.eport)) while self.is_server_working: - rqst, _, _ = select.select([schan, echo], - [], - [], - timeout) + # On Windows, only sockets are supported + rqst, _, _ = select.select([schan, echo], [], [], timeout) if schan in rqst: data = schan.recv(1024) self.log.debug('%s -->: %s', info, repr(data)) @@ -334,26 +336,21 @@ def _run_ssh_server(self): get_test_data_path(PKEY_FILE) ) self.ts.add_server_key(host_key) - server = NullServer(allowed_keys=FINGERPRINTS.keys(), - log=self.log) - t = threading.Thread(target=self._do_forwarding, - name='forward-server') + server = NullServer(allowed_keys=FINGERPRINTS.keys(), log=self.log) + t = threading.Thread(target=self._do_forwarding, name='forward-server') t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t t.start() self.ts.start_server(self.ssh_event, server) - self.wait_for_thread(t, - timeout=None, - who='ssh-server') + self.wait_for_thread(t, timeout=None, who='ssh-server') self.log.info('ssh-server shutting down') self.running_threads.remove('ssh-server') def start_echo_and_ssh_server(self): self.is_server_working = True self.start_echo_server() - t = threading.Thread(target=self._run_ssh_server, - name='ssh-server') + t = threading.Thread(target=self._run_ssh_server, name='ssh-server') t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t @@ -368,8 +365,7 @@ def _check_server_auth(self): self.ssh_event.wait(sshtunnel.SSH_TIMEOUT) # wait for transport self.assertTrue(self.ssh_event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual(self.ts.get_username(), - SSH_USERNAME) + self.assertEqual(self.ts.get_username(), SSH_USERNAME) self.assertTrue(self.ts.is_authenticated()) @contextmanager @@ -387,10 +383,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): socks = [self.esockl] try: while self.is_server_working: - inputready, _, _ = select.select(socks, - [], - [], - timeout) + inputready, _, _ = select.select(socks, [], [], timeout) for s in inputready: if s == self.esockl: # handle the server socket @@ -428,8 +421,7 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): self.running_threads.remove('echo-server') def start_echo_server(self): - t = threading.Thread(target=self._run_echo_server, - name='echo-server') + t = threading.Thread(target=self._run_echo_server, name='echo-server') t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t @@ -451,17 +443,16 @@ def test_echo_server(self): self.log.info('_test_server(): try connect!') s = socket.create_connection(local_bind_addr) self.log.info( - '_test_server(): connected from %s! try send!', - s.getsockname() + '_test_server(): connected from %s! try send!', s.getsockname() ) s.send(message) self.log.info('_test_server(): sent!') - z = (s.recv(1000)) + z = s.recv(1000) self.assertEqual(z, message) s.close() def test_connect_by_username_password(self): - """ Test connecting using username/password as authentication """ + """Test connecting using username/password as authentication""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -472,7 +463,7 @@ def test_connect_by_username_password(self): pass # no exceptions are raised def test_connect_by_rsa_key_file(self): - """ Test connecting using a RSA key file """ + """Test connecting using a RSA key file""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -483,7 +474,7 @@ def test_connect_by_rsa_key_file(self): pass # no exceptions are raised def test_connect_by_paramiko_key(self): - """ Test connecting when ssh_private_key is a paramiko.RSAKey """ + """Test connecting when ssh_private_key is a paramiko.RSAKey""" ssh_key = paramiko.RSAKey.from_private_key_file( get_test_data_path(PKEY_FILE) ) @@ -497,7 +488,7 @@ def test_connect_by_paramiko_key(self): pass def test_open_tunnel(self): - """ Test wrapper method mainly used from CLI """ + """Test wrapper method mainly used from CLI""" server = sshtunnel.open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -567,7 +558,7 @@ def test_unknown_argument_raises_exception(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - i_do_not_exist=0 + i_do_not_exist=0, ) def test_more_local_than_remote_bind_sizes_raises_exception(self): @@ -581,8 +572,10 @@ def test_more_local_than_remote_bind_sizes_raises_exception(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_addresses=[ + ('127.0.0.1', self.eport), + ('127.0.0.1', self.randomize_eport()), + ], ) def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): @@ -597,8 +590,10 @@ def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('127.0.0.1', self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_addresses=[ + ('127.0.0.1', self.eport), + ('127.0.0.1', self.randomize_eport()), + ], ) def test_localbindaddress_host_is_optional(self): @@ -611,7 +606,7 @@ def test_localbindaddress_host_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('', self.randomize_eport()) + local_bind_address=('', self.randomize_eport()), ) as server: self.assertEqual(server.local_bind_host, '0.0.0.0') @@ -625,7 +620,7 @@ def test_localbindaddress_port_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('127.0.0.1', ) + local_bind_address=('127.0.0.1',), ) as server: self.assertIsInstance(server.local_bind_port, int) @@ -640,8 +635,10 @@ def test_remotebindaddress_and_remotebindaddresses_are_exclusive(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - remote_bind_addresses=[(self.eaddr, self.eport), - (self.eaddr, self.randomize_eport())] + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.eaddr, self.randomize_eport()), + ], ) def test_no_remote_bind_address_raises_exception(self): @@ -655,8 +652,10 @@ def test_no_remote_bind_address_raises_exception(self): ssh_username=SSH_USERNAME, ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_reading_from_a_bad_sshconfigfile_does_not_raise_error(self): """ Test that when a bad ssh_config file is found, a warning is shown @@ -671,7 +670,7 @@ def test_reading_from_a_bad_sshconfigfile_does_not_raise_error(self): remote_bind_address=(self.eaddr, self.eport), local_bind_address=('127.0.0.1', self.randomize_eport()), logger=self.log, - ssh_config_file=ssh_config_file + ssh_config_file=ssh_config_file, ) logged_message = 'Could not read SSH configuration file: {0}'.format( ssh_config_file @@ -688,11 +687,10 @@ def test_not_setting_password_or_pkey_raises_error(self): (self.saddr, self.sport), ssh_username=SSH_USERNAME, remote_bind_address=(self.eaddr, self.eport), - ssh_config_file=None + ssh_config_file=None, ) - @unittest.skipIf(os.name == 'nt', - reason='Need to fix test on Windows') + @unittest.skipIf(os.name == 'nt', reason='Need to fix test on Windows') def test_deprecate_warnings_are_shown(self): """Test that when using deprecate arguments a warning is logged""" warnings.simplefilter('always') # don't ignore DeprecationWarnings @@ -706,9 +704,11 @@ def test_deprecate_warnings_are_shown(self): 'remote_bind_address': (self.eaddr, self.eport), } open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) + logged_message = ( + "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_arg, sshtunnel._DEPRECATIONS[deprecated_arg] + ) + ) self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) self.assertEqual(logged_message, str(w[-1].message)) @@ -716,7 +716,7 @@ def test_deprecate_warnings_are_shown(self): with warnings.catch_warnings(record=True) as w: for deprecated_arg in [ 'raise_exception_if_any_forwarder_have_a_problem', - 'ssh_private_key' + 'ssh_private_key', ]: _kwargs = { 'ssh_address_or_host': (self.saddr, self.sport), @@ -726,9 +726,11 @@ def test_deprecate_warnings_are_shown(self): deprecated_arg: (self.saddr, self.sport), } open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) + logged_message = ( + "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_arg, sshtunnel._DEPRECATIONS[deprecated_arg] + ) + ) self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) self.assertEqual(logged_message, str(w[-1].message)) @@ -749,8 +751,10 @@ def test_gateway_unreachable_raises_exception(self): ): pass - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_gateway_ip_unresolvable_raises_exception(self): """ BaseSSHTunnelForwarderError is raised when not able to resolve the @@ -769,29 +773,35 @@ def test_gateway_ip_unresolvable_raises_exception(self): 'Could not resolve IP address for {0}, aborting!'.format( SSH_USERNAME ), - self.sshtunnel_log_messages['error'] + self.sshtunnel_log_messages['error'], ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_running_start_twice_logs_warning(self): """Test that when running start() twice a warning is shown""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_address=(self.eaddr, self.eport) + remote_bind_address=(self.eaddr, self.eport), ) as server: - self.assertNotIn('Already started!', - self.sshtunnel_log_messages['warning']) + self.assertNotIn( + 'Already started!', self.sshtunnel_log_messages['warning'] + ) server.logger.error(server.is_active) server.logger.error(server.is_alive) server.start() # 2nd start should prompt the warning - self.assertIn('Already started!', - self.sshtunnel_log_messages['warning']) + self.assertIn( + 'Already started!', self.sshtunnel_log_messages['warning'] + ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_stop_before_start_logs_warning(self): """ Test that running .stop() on an already stopped server logs a warning @@ -805,11 +815,15 @@ def test_stop_before_start_logs_warning(self): logger=self.log, ) server.stop() - self.assertIn('Server is not started. Please .start() first!', - self.sshtunnel_log_messages['warning']) + self.assertIn( + 'Server is not started. Please .start() first!', + self.sshtunnel_log_messages['warning'], + ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_wrong_auth_to_gateway_logs_error(self): """ Test that when connecting to the ssh gateway with wrong credentials, @@ -824,11 +838,15 @@ def test_wrong_auth_to_gateway_logs_error(self): logger=self.log, ): pass - self.assertIn('Could not open connection to gateway', - self.sshtunnel_log_messages['error']) + self.assertIn( + 'Could not open connection to gateway', + self.sshtunnel_log_messages['error'], + ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_missing_pkey_file_logs_warning(self): """ Test that when the private key file is missing, a warning is logged @@ -842,13 +860,16 @@ def test_missing_pkey_file_logs_warning(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ): - self.assertIn('Private key file not found: {0}'.format(bad_pkey), - self.sshtunnel_log_messages['warning']) + self.assertIn( + 'Private key file not found: {0}'.format(bad_pkey), + self.sshtunnel_log_messages['warning'], + ) def test_connect_via_proxy(self): - """ Test connecting using a ProxyCommand """ - proxycmd = paramiko.proxy.ProxyCommand('ssh proxy -W {0}:{1}' - .format(self.saddr, self.sport)) + """Test connecting using a ProxyCommand""" + proxycmd = paramiko.proxy.ProxyCommand( + 'ssh proxy -W {0}:{1}'.format(self.saddr, self.sport) + ) server = open_tunnel( self.saddr, ssh_username=SSH_USERNAME, @@ -860,10 +881,12 @@ def test_connect_via_proxy(self): ) self.assertEqual(server.ssh_proxy.cmd[1], 'proxy') - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_can_skip_loading_sshconfig(self): - """ Test that we can skip loading the ~/.ssh/config file """ + """Test that we can skip loading the ~/.ssh/config file""" server = open_tunnel( (self.saddr, self.sport), ssh_password=SSH_PASSWORD, @@ -872,11 +895,13 @@ def test_can_skip_loading_sshconfig(self): logger=self.log, ) self.assertEqual(server.ssh_username, getpass.getuser()) - self.assertIn('Skipping loading of ssh configuration file', - self.sshtunnel_log_messages['info']) + self.assertIn( + 'Skipping loading of ssh configuration file', + self.sshtunnel_log_messages['info'], + ) def test_local_bind_port(self): - """ Test local_bind_port property """ + """Test local_bind_port property""" s = socket.socket() s.bind(('localhost', 0)) addr, port = s.getsockname() @@ -893,7 +918,7 @@ def test_local_bind_port(self): self.assertEqual(server.local_bind_port, port) def test_local_bind_host(self): - """ Test local_bind_host property """ + """Test local_bind_host property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -906,7 +931,7 @@ def test_local_bind_host(self): self.assertEqual(server.local_bind_host, self.saddr) def test_local_bind_address(self): - """ Test local_bind_address property """ + """Test local_bind_address property""" s = socket.socket() s.bind(('localhost', 0)) addr, port = s.getsockname() @@ -923,13 +948,15 @@ def test_local_bind_address(self): self.assertTupleEqual(server.local_bind_address, (addr, port)) def test_local_bind_ports(self): - """ Test local_bind_ports property """ + """Test local_bind_ports property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_ports, list) @@ -947,44 +974,50 @@ def test_local_bind_ports(self): self.assertIsInstance(server.local_bind_ports, list) def test_local_bind_hosts(self): - """ Test local_bind_hosts property """ + """Test local_bind_hosts property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_hosts, list) - self.assertListEqual(server.local_bind_hosts, - [self.saddr] * 2) + self.assertListEqual(server.local_bind_hosts, [self.saddr] * 2) with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_host) def test_local_bind_addresses(self): - """ Test local_bind_addresses property """ + """Test local_bind_addresses property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_addresses, list) - self.assertListEqual(server.local_bind_addresses, - list(zip([self.saddr] * 2, - server.local_bind_ports))) + self.assertListEqual( + server.local_bind_addresses, + list(zip([self.saddr] * 2, server.local_bind_ports)), + ) with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_address) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_check_tunnels(self): - """ Test method checking if tunnels are up """ + """Test method checking if tunnels are up""" remote_address = (self.eaddr, self.eport) with self._test_server( (self.saddr, self.sport), @@ -994,49 +1027,65 @@ def test_check_tunnels(self): logger=self.log, skip_tunnel_checkup=False, ) as server: - self.assertIn('Tunnel to {0} is UP'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn( + 'Tunnel to {0} is UP'.format(remote_address), + self.sshtunnel_log_messages['debug'], + ) server.check_tunnels() - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn( + 'Tunnel to {0} is DOWN'.format(remote_address), + self.sshtunnel_log_messages['debug'], + ) # Calling local_is_up() should also return the same server.skip_tunnel_checkup = True server.local_is_up((self.saddr, self.sport)) - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn( + 'Tunnel to {0} is DOWN'.format(remote_address), + self.sshtunnel_log_messages['debug'], + ) - self.assertFalse(server.local_is_up("not a valid address")) - self.assertIn('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.', - self.sshtunnel_log_messages['warning']) + self.assertFalse(server.local_is_up('not a valid address')) + self.assertIn( + 'Target must be a tuple (IP, port), where IP ' + 'is a string (i.e. "192.168.0.1") and port is ' + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.', + self.sshtunnel_log_messages['warning'], + ) @mock.patch('sshtunnel.input_', return_value=linesep) def test_cli_main_exits_when_pressing_enter(self, input): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" self.start_echo_and_ssh_server() - sshtunnel._cli_main(args=[self.saddr, - '-U', SSH_USERNAME, - '-P', SSH_PASSWORD, - '-p', str(self.sport), - '-R', '{0}:{1}'.format(self.eaddr, - self.eport), - '-c', '', - '-n'], - host_pkey_directories=[]) + sshtunnel._cli_main( + args=[ + self.saddr, + '-U', + SSH_USERNAME, + '-P', + SSH_PASSWORD, + '-p', + str(self.sport), + '-R', + '{0}:{1}'.format(self.eaddr, self.eport), + '-c', + '', + '-n', + ], + host_pkey_directories=[], + ) self.stop_echo_and_ssh_server() - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_read_private_key_file(self): - """ Test that an encrypted private key can be opened """ + """Test that an encrypted private key can be opened""" encr_pkey = get_test_data_path(ENCRYPTED_PKEY_FILE) pkey = sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='sshtunnel', - logger=self.log + encr_pkey, pkey_password='sshtunnel', logger=self.log ) _pkey = paramiko.RSAKey.from_private_key_file( get_test_data_path(PKEY_FILE) @@ -1044,27 +1093,33 @@ def test_read_private_key_file(self): self.assertEqual(pkey, _pkey) # Using a wrong password returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='bad password', - logger=self.log - )) - self.assertIn("Private key file ({0}) could not be loaded as type " - "{1} or bad password" - .format(encr_pkey, type(_pkey)), - self.sshtunnel_log_messages['debug']) + self.assertIsNone( + sshtunnel.SSHTunnelForwarder.read_private_key_file( + encr_pkey, pkey_password='bad password', logger=self.log + ) + ) + self.assertIn( + 'Private key file ({0}) could not be loaded as type ' + '{1} or bad password'.format(encr_pkey, type(_pkey)), + self.sshtunnel_log_messages['debug'], + ) # Using no password on an encrypted key returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - logger=self.log - )) - self.assertIn('Password is required for key {0}'.format(encr_pkey), - self.sshtunnel_log_messages['error']) - - @unittest.skipIf(os.name != 'posix', - reason="UNIX sockets not supported on this platform") + self.assertIsNone( + sshtunnel.SSHTunnelForwarder.read_private_key_file( + encr_pkey, logger=self.log + ) + ) + self.assertIn( + 'Password is required for key {0}'.format(encr_pkey), + self.sshtunnel_log_messages['error'], + ) + + @unittest.skipIf( + os.name != 'posix', + reason='UNIX sockets not supported on this platform', + ) def test_unix_domains(self): - """ Test use of UNIX domain sockets in local binds """ + """Test use of UNIX domain sockets in local binds""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -1075,14 +1130,15 @@ def test_unix_domains(self): ) as server: self.assertEqual(server.local_bind_address, TEST_UNIX_SOCKET) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") + @unittest.skipIf( + sys.version_info < (2, 7), + reason='Cannot intercept logging messages in py26', + ) def test_tracing_logging(self): """ Test that Tracing mode may be enabled for more fine-grained logs """ - logger = sshtunnel.create_logger(logger=self.log, - loglevel='TRACE') + logger = sshtunnel.create_logger(logger=self.log, loglevel='TRACE') with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -1090,8 +1146,9 @@ def test_tracing_logging(self): remote_bind_address=(self.eaddr, self.eport), logger=logger, ) as server: - server.logger = sshtunnel.create_logger(logger=server.logger, - loglevel='TRACE') + server.logger = sshtunnel.create_logger( + logger=server.logger, loglevel='TRACE' + ) message = get_random_string(100).encode() # Windows raises WinError 10049 if trying to connect to 0.0.0.0 s = socket.create_connection(('127.0.0.1', server.local_bind_port)) @@ -1100,11 +1157,11 @@ def test_tracing_logging(self): s.close log = 'send to {0}'.format((self.eaddr, self.eport)) - self.assertTrue(any(log in msg for msg in - self.sshtunnel_log_messages['trace'])) + self.assertTrue( + any(log in msg for msg in self.sshtunnel_log_messages['trace']) + ) # set loglevel back to the original value - logger = sshtunnel.create_logger(logger=self.log, - loglevel='DEBUG') + logger = sshtunnel.create_logger(logger=self.log, loglevel='DEBUG') def test_tunnel_bindings_contain_active_tunnels(self): """ @@ -1116,20 +1173,24 @@ def test_tunnel_bindings_contain_active_tunnels(self): (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, remote_ports[0]), - (self.eaddr, remote_ports[1])], - local_bind_addresses=[('127.0.0.1', local_ports[0]), - ('127.0.0.1', local_ports[1])], + remote_bind_addresses=[ + (self.eaddr, remote_ports[0]), + (self.eaddr, remote_ports[1]), + ], + local_bind_addresses=[ + ('127.0.0.1', local_ports[0]), + ('127.0.0.1', local_ports[1]), + ], skip_tunnel_checkup=False, ) as server: self.assertListEqual(server.local_bind_ports, local_ports) self.assertTupleEqual( server.tunnel_bindings[(self.eaddr, remote_ports[0])], - ('127.0.0.1', local_ports[0]) + ('127.0.0.1', local_ports[0]), ) self.assertTupleEqual( server.tunnel_bindings[(self.eaddr, remote_ports[1])], - ('127.0.0.1', local_ports[1]) + ('127.0.0.1', local_ports[1]), ) def check_make_ssh_forward_server_sets_daemon(self, case): @@ -1165,19 +1226,23 @@ def test_make_ssh_forward_server_sets_daemon_false(self): self.check_make_ssh_forward_server_sets_daemon(case=False) def test_get_keys(self): - """ Test loading keys from the paramiko Agent """ + """Test loading keys from the paramiko Agent""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('', self.randomize_eport()), - logger=self.log + logger=self.log, ) as server: keys = server.get_keys(logger=self.log) self.assertIsInstance(keys, list) - self.assertFalse(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + self.assertFalse( + any( + 'keys loaded from agent' in msg + for msg in self.sshtunnel_log_messages['info'] + ) + ) with self._test_server( (self.saddr, self.sport), @@ -1185,41 +1250,53 @@ def test_get_keys(self): ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('', self.randomize_eport()), - logger=self.log + logger=self.log, ) as server: keys = server.get_keys(logger=self.log, allow_agent=True) self.assertIsInstance(keys, list) - self.assertTrue(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + self.assertTrue( + any( + 'keys loaded from agent' in msg + for msg in self.sshtunnel_log_messages['info'] + ) + ) tmp_dir = tempfile.mkdtemp() - shutil.copy(get_test_data_path(PKEY_FILE), - os.path.join(tmp_dir, 'id_rsa')) + shutil.copy( + get_test_data_path(PKEY_FILE), os.path.join(tmp_dir, 'id_rsa') + ) keys = sshtunnel.SSHTunnelForwarder.get_keys( self.log, - host_pkey_directories=[tmp_dir, ] + host_pkey_directories=[ + tmp_dir, + ], ) self.assertIsInstance(keys, list) self.assertTrue( - any('1 key(s) loaded' in msg - for msg in self.sshtunnel_log_messages['info']) + any( + '1 key(s) loaded' in msg + for msg in self.sshtunnel_log_messages['info'] + ) ) shutil.rmtree(tmp_dir) class AuxiliaryTest(unittest.TestCase): - """ Set of tests that do not need the mock SSH server or logger """ + """Set of tests that do not need the mock SSH server or logger""" def _test_parser(self, parser): self.assertEqual(parser['ssh_address'], '10.10.10.10') self.assertEqual(parser['ssh_username'], getpass.getuser()) self.assertEqual(parser['ssh_port'], 22) self.assertEqual(parser['ssh_password'], SSH_PASSWORD) - self.assertListEqual(parser['remote_bind_addresses'], - [('10.0.0.1', 8080), ('10.0.0.2', 8080)]) - self.assertListEqual(parser['local_bind_addresses'], - [('', 8081), ('', 8082)]) + self.assertListEqual( + parser['remote_bind_addresses'], + [('10.0.0.1', 8080), ('10.0.0.2', 8080)], + ) + self.assertListEqual( + parser['local_bind_addresses'], [('', 8081), ('', 8082)] + ) self.assertEqual(parser['ssh_host_key'], str(SSH_DSS)) self.assertEqual(parser['ssh_private_key'], __file__) self.assertEqual(parser['ssh_private_key_password'], SSH_PASSWORD) @@ -1231,23 +1308,28 @@ def _test_parser(self, parser): self.assertFalse(parser['allow_agent']) def test_parse_arguments_short(self): - """ Test CLI argument parsing with short parameter names """ - args = ['10.10.10.10', # ssh_address - '-U={0}'.format(getpass.getuser()), # GW username - '-p=22', # GW SSH port - '-P={0}'.format(SSH_PASSWORD), # GW password - '-R', '10.0.0.1:8080', '10.0.0.2:8080', # remote bind list - '-L', ':8081', ':8082', # local bind list - '-k={0}'.format(SSH_DSS), # hostkey - '-K={0}'.format(__file__), # pkey file - '-S={0}'.format(SSH_PASSWORD), # pkey password - '-t', # concurrent connections (threaded) - '-vvv', # triple verbosity - '-x=10.0.0.2:', # proxy address - '-c=ssh_config', # ssh configuration file - '-z', # request compression - '-n', # disable SSH agent key lookup - ] + """Test CLI argument parsing with short parameter names""" + args = [ + '10.10.10.10', # ssh_address + '-U={0}'.format(getpass.getuser()), # GW username + '-p=22', # GW SSH port + '-P={0}'.format(SSH_PASSWORD), # GW password + '-R', + '10.0.0.1:8080', + '10.0.0.2:8080', # remote bind list + '-L', + ':8081', + ':8082', # local bind list + '-k={0}'.format(SSH_DSS), # hostkey + '-K={0}'.format(__file__), # pkey file + '-S={0}'.format(SSH_PASSWORD), # pkey password + '-t', # concurrent connections (threaded) + '-vvv', # triple verbosity + '-x=10.0.0.2:', # proxy address + '-c=ssh_config', # ssh configuration file + '-z', # request compression + '-n', # disable SSH agent key lookup + ] parser = sshtunnel._parse_arguments(args) self._test_parser(parser) @@ -1260,24 +1342,33 @@ def test_parse_arguments_short(self): parser = sshtunnel._parse_arguments(args[:4] + args[5:]) def test_parse_arguments_long(self): - """ Test CLI argument parsing with long parameter names """ + """Test CLI argument parsing with long parameter names""" parser = sshtunnel._parse_arguments( - ['10.10.10.10', # ssh_address - '--username={0}'.format(getpass.getuser()), # GW username - '--server_port=22', # GW SSH port - '--password={0}'.format(SSH_PASSWORD), # GW password - '--remote_bind_address', '10.0.0.1:8080', '10.0.0.2:8080', - '--local_bind_address', ':8081', ':8082', # local bind list - '--ssh_host_key={0}'.format(SSH_DSS), # hostkey - '--private_key_file={0}'.format(__file__), # pkey file - '--private_key_password={0}'.format(SSH_PASSWORD), - '--threaded', # concurrent connections (threaded) - '--verbose', '--verbose', '--verbose', # triple verbosity - '--proxy', '10.0.0.2:22', # proxy address - '--config', 'ssh_config', # ssh configuration file - '--compress', # request compression - '--noagent', # disable SSH agent key lookup - ] + [ + '10.10.10.10', # ssh_address + '--username={0}'.format(getpass.getuser()), # GW username + '--server_port=22', # GW SSH port + '--password={0}'.format(SSH_PASSWORD), # GW password + '--remote_bind_address', + '10.0.0.1:8080', + '10.0.0.2:8080', + '--local_bind_address', + ':8081', + ':8082', # local bind list + '--ssh_host_key={0}'.format(SSH_DSS), # hostkey + '--private_key_file={0}'.format(__file__), # pkey file + '--private_key_password={0}'.format(SSH_PASSWORD), + '--threaded', # concurrent connections (threaded) + '--verbose', + '--verbose', + '--verbose', # triple verbosity + '--proxy', + '10.0.0.2:22', # proxy address + '--config', + 'ssh_config', # ssh configuration file + '--compress', # request compression + '--noagent', # disable SSH agent key lookup + ] ) self._test_parser(parser) @@ -1285,20 +1376,23 @@ def test_bindlist(self): """ Test that _bindlist enforces IP:PORT format for local and remote binds """ - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:8080'), - ('10.0.0.1', 8080)) + self.assertTupleEqual( + sshtunnel._bindlist('10.0.0.1:8080'), ('10.0.0.1', 8080) + ) # Missing port in tuple is filled with port 22 - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:'), - ('10.0.0.1', 22)) - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1'), - ('10.0.0.1', 22)) + self.assertTupleEqual( + sshtunnel._bindlist('10.0.0.1:'), ('10.0.0.1', 22) + ) + self.assertTupleEqual( + sshtunnel._bindlist('10.0.0.1'), ('10.0.0.1', 22) + ) with self.assertRaises(argparse.ArgumentTypeError): sshtunnel._bindlist('10022:10.0.0.1:22') with self.assertRaises(argparse.ArgumentTypeError): sshtunnel._bindlist(':') def test_raise_fwd_ext(self): - """ Test that we can silence the exceptions on sshtunnel creation """ + """Test that we can silence the exceptions on sshtunnel creation""" server = open_tunnel( '10.10.10.10', ssh_username=SSH_USERNAME, @@ -1314,7 +1408,7 @@ def test_raise_fwd_ext(self): server._raise(sshtunnel.BaseSSHTunnelForwarderError, 'test') def test_show_running_version(self): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" with capture_stdout_stderr() as (out, err): with self.assertRaises(SystemExit): sshtunnel._cli_main(args=['-V']) @@ -1322,26 +1416,26 @@ def test_show_running_version(self): version = err.getvalue().split()[-1] else: version = out.getvalue().split()[-1] - self.assertEqual(version, - sshtunnel.__version__) + self.assertEqual(version, sshtunnel.__version__) def test_remove_none_values(self): - """ Test removing keys from a dict where values are None """ + """Test removing keys from a dict where values are None""" test_dict = {'key1': 1, 'key2': None, 'key3': 3, 'key4': 0} sshtunnel._remove_none_values(test_dict) - self.assertDictEqual(test_dict, - {'key1': 1, 'key3': 3, 'key4': 0}) + self.assertDictEqual(test_dict, {'key1': 1, 'key3': 3, 'key4': 0}) def test_read_ssh_config(self): - """ Test that we can gather host information from a config file """ - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'test', - get_test_data_path(TEST_CONFIG_FILE), + """Test that we can gather host information from a config file""" + ( + ssh_hostname, + ssh_username, + ssh_private_key, + ssh_port, + ssh_proxy, + compression, + ) = sshtunnel.SSHTunnelForwarder._read_ssh_config( + 'test', + get_test_data_path(TEST_CONFIG_FILE), ) self.assertEqual(ssh_hostname, 'test') self.assertEqual(ssh_username, 'test') @@ -1351,15 +1445,15 @@ def test_read_ssh_config(self): self.assertTrue(compression) # passed parameters are not overriden by config - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'other', - get_test_data_path(TEST_CONFIG_FILE), - compression=False + ( + ssh_hostname, + ssh_username, + ssh_private_key, + ssh_port, + ssh_proxy, + compression, + ) = sshtunnel.SSHTunnelForwarder._read_ssh_config( + 'other', get_test_data_path(TEST_CONFIG_FILE), compression=False ) self.assertEqual(ssh_hostname, '10.0.0.1') self.assertEqual(ssh_port, 222) @@ -1379,36 +1473,35 @@ def test_str(self): self.assertIn('status: not started', _str) def test_process_deprecations(self): - """ Test processing deprecated API attributes """ - kwargs = {'ssh_host': '10.0.0.1', - 'ssh_address': '10.0.0.1', - 'ssh_private_key': 'testrsa.key', - 'raise_exception_if_any_forwarder_have_a_problem': True} + """Test processing deprecated API attributes""" + kwargs = { + 'ssh_host': '10.0.0.1', + 'ssh_address': '10.0.0.1', + 'ssh_private_key': 'testrsa.key', + 'raise_exception_if_any_forwarder_have_a_problem': True, + } for item, value in kwargs.items(): self.assertEqual( value, sshtunnel.SSHTunnelForwarder._process_deprecated( - None, - item, - kwargs.copy() - ) + None, item, kwargs.copy() + ), ) # use both deprecated and not None new attribute should raise exception for item in kwargs: with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - item, - kwargs.copy()) + sshtunnel.SSHTunnelForwarder._process_deprecated( + 'some value', item, kwargs.copy() + ) # deprecated attribute not in deprecation list should raise exception with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - 'item', - kwargs.copy()) + sshtunnel.SSHTunnelForwarder._process_deprecated( + 'some value', 'item', kwargs.copy() + ) def test_check_address(self): - """ Test that an exception is raised with incorrect bind addresses """ - address_list = [('10.0.0.1', 10000), - ('10.0.0.1', 10001)] + """Test that an exception is raised with incorrect bind addresses""" + address_list = [('10.0.0.1', 10000), ('10.0.0.1', 10001)] if os.name == 'posix': # UNIX sockets supported by the platform address_list.append('/tmp/unix-socket') # UNIX sockets not supported on remote addresses From 9b61117136e45f1e1ec0d2f83c64ebad2d9d0b3b Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 18:50:24 -0500 Subject: [PATCH 45/46] expand caught errors --- sshtunnel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 0d69c4c..a90f689 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -396,7 +396,7 @@ def handle(self): self.logger.log(TRACE_LEVEL, '%s connected', self.info) try: self._redirect(chan) - except OSError: + except (OSError, socket.error): # Sometimes a RST is sent and a socket error is raised, treat this # exception. It was seen that a 3way FIN is processed later on, so # no need to make an ordered close of the connection here or raise @@ -832,7 +832,7 @@ def _read_ssh_config( if compression is None: compression = hostname_info.get('compression', '') compression = compression.upper() == 'YES' - except (IOError, AssertionError, OSError): + except (IOError, OSError): if logger: logger.warning( 'Could not read SSH configuration file: %s', @@ -1158,7 +1158,7 @@ def _check_tunnel(self, _srv): timeout=TUNNEL_TIMEOUT * 1.1 ) self.logger.debug('Tunnel to %s is DOWN', _srv.remote_address) - except OSError: + except (OSError, socket.error): self.logger.debug('Tunnel to %s is DOWN', _srv.remote_address) self.tunnel_is_up[_srv.local_address] = False @@ -1241,7 +1241,7 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): address_to_str(remote_address), ), ) - except OSError: + except (OSError, IOError): self._raise( BaseSSHTunnelForwarderError, "Couldn't open tunnel {0} <> {1} might be in use or " @@ -1475,7 +1475,7 @@ def _create_tunnels(self): self.ssh_host, ) return - except (OSError, paramiko.SSHException) as e: + except (OSError, paramiko.SSHException, socket.error) as e: self.logger.error( 'Could not connect to gateway %s:%s : %s', self.ssh_host, From 411e97de04904b1e48f82ca20d6ae28ba135b2ff Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 18:58:43 -0500 Subject: [PATCH 46/46] remove trailing comma from *kwargs --- sshtunnel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sshtunnel.py b/sshtunnel.py index a90f689..2dea3c7 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1006,7 +1006,7 @@ def __init__( allow_agent=True, # look for keys from an SSH agent host_pkey_directories=None, # look for keys in ~/.ssh *args, - **kwargs, # for backwards compatibility + **kwargs # for backwards compatibility ): self.logger = logger or create_logger()