From 2b1451889e7e68b2147738ce893458defa95b76c Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 5 Mar 2026 19:25:45 -0500 Subject: [PATCH 01/25] use pytest syntax, resolve test warnings, fix windows deprecation tests --- tests/requirements.txt | 2 +- tests/test_forwarder.py | 1206 +++++++++++++++++++++------------------ tox.ini | 15 +- 3 files changed, 667 insertions(+), 556 deletions(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 6a91ea46..c1d3e7d3 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ coveralls mock -pytest +pytest>=4 pytest-cov pytest-xdist twine diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 40662d08..23bb1e6a 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1,38 +1,40 @@ from __future__ import with_statement +import argparse +import getpass +import logging import os -import sys import random +import re import select +import shutil import socket -import getpass -import logging -import argparse -import warnings +import sys import threading -from os import path, linesep -from functools import partial +import warnings from contextlib import contextmanager +from functools import partial +from os import linesep, path +from typing import List, Tuple, Union import mock import paramiko +import pytest + import sshtunnel -import shutil -import tempfile if sys.version_info[0] == 2: from cStringIO import StringIO - if sys.version_info < (2, 7): - import unittest2 as unittest - else: - import unittest else: - import unittest from io import StringIO +sshtunnel.TUNNEL_TIMEOUT = 1 + + # UTILS + def get_random_string(length=12): """ >>> r = get_random_string(1) @@ -50,7 +52,7 @@ def get_random_string(length=12): def get_test_data_path(x): - return path.join(HERE, x) + return path.join(path.abspath(path.dirname(__file__)), x) @contextmanager @@ -90,7 +92,6 @@ def capture_stdout_stderr(): 'ecdsa-sha2-nistp256': ECDSA, } DAEMON_THREADS = False -HERE = path.abspath(path.dirname(__file__)) THREADS_TIMEOUT = 5.0 PKEY_FILE = 'testrsa.key' ENCRYPTED_PKEY_FILE = 'testrsa_encrypted.key' @@ -103,6 +104,7 @@ def capture_stdout_stderr(): # TESTS + class MockLoggingHandler(logging.Handler, object): """Mock logging handler to check for expected logs. @@ -111,12 +113,18 @@ class MockLoggingHandler(logging.Handler, object): """ def __init__(self, *args, **kwargs): - self.messages = {'debug': [], 'info': [], 'warning': [], 'error': [], - 'critical': [], 'trace': []} + self.messages = { + 'debug': [], + 'info': [], + 'warning': [], + 'error': [], + 'critical': [], + 'trace': [], + } super(MockLoggingHandler, self).__init__(*args, **kwargs) def emit(self, record): - "Store a message from ``record`` in the instance's ``messages`` dict." + """Store a message from ``record`` in ``self.messages`` dict.""" self.acquire() try: self.messages[record.levelname.lower()].append(record.getMessage()) @@ -140,135 +148,242 @@ def __init__(self, *args, **kwargs): super(NullServer, self).__init__(*args, **kwargs) def check_channel_forward_agent_request(self, channel): - self.log.debug('NullServer.check_channel_forward_agent_request() {0}' - .format(channel)) + self.log.debug( + 'NullServer.check_channel_forward_agent_request() {0}'.format( + channel + ) + ) return False def get_allowed_auths(self, username): allowed_auths = 'publickey{0}'.format( ',password' if username == SSH_USERNAME else '' ) - self.log.debug('NullServer >> allowed auths for {0}: {1}' - .format(username, allowed_auths)) + self.log.debug( + 'NullServer >> allowed auths for {0}: {1}'.format( + username, allowed_auths + ) + ) return allowed_auths def check_auth_password(self, username, password): - _ok = (username == SSH_USERNAME and password == SSH_PASSWORD) - self.log.debug('NullServer >> password for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + _ok = username == SSH_USERNAME and password == SSH_PASSWORD + self.log.debug( + 'NullServer >> password for {0} {1}OK'.format( + username, '' if _ok else 'NOT-' + ) + ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): try: expected = FINGERPRINTS[key.get_name()] - _ok = (key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected) + _ok = ( + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected + ) except KeyError: _ok = False - self.log.debug('NullServer >> pkey authentication for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + self.log.debug( + 'NullServer >> pkey authentication for {0} {1}OK'.format( + username, '' if _ok else 'NOT-' + ) + ) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): - self.log.debug('NullServer.check_channel_request()') + self.log.debug( + 'NullServer.check_channel_request({0}, {1})'.format(kind, chanid) + ) return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - self.log.debug('NullServer.check_channel_exec_request()') + self.log.debug( + 'NullServer.check_channel_exec_request({0}, {1})'.format( + channel, command + ) + ) return True def check_port_forward_request(self, address, port): - self.log.debug('NullServer.check_port_forward_request()') + self.log.debug( + 'NullServer.check_port_forward_request({0}, {1})'.format( + address, port + ) + ) return True def check_global_request(self, kind, msg): - self.log.debug('NullServer.check_port_forward_request()') + self.log.debug( + 'NullServer.check_global_request(kind={0})'.format(kind) + ) return True def check_channel_direct_tcpip_request(self, chanid, origin, destination): - self.log.debug('NullServer.check_channel_direct_tcpip_request' - '(chanid={0}) {1} -> {2}' - .format(chanid, origin, destination)) + self.log.debug( + 'NullServer.check_channel_direct_tcpip_request' + '(chanid={0}) {1} -> {2}'.format(chanid, origin, destination) + ) return paramiko.OPEN_SUCCEEDED -class SSHClientTest(unittest.TestCase): - def make_socket(self): +class TestSSHClient: + @staticmethod + def make_socket(): s = socket.socket() s.bind(('localhost', 0)) s.listen(5) addr, port = s.getsockname() return s, addr, port - @classmethod - def setUpClass(cls): - super(SSHClientTest, cls).setUpClass() + @pytest.fixture(autouse=True, scope='function') + def setup_ssh_environment(self, request): socket.setdefaulttimeout(sshtunnel.SSH_TIMEOUT) - cls.log = logging.getLogger(sshtunnel.__name__) - cls.log = sshtunnel.create_logger(logger=cls.log, - loglevel='DEBUG') - cls._sshtunnel_log_handler = MockLoggingHandler(level='DEBUG') - cls.log.addHandler(cls._sshtunnel_log_handler) - cls.sshtunnel_log_messages = cls._sshtunnel_log_handler.messages + self.log = logging.getLogger(sshtunnel.__name__) + self.log = sshtunnel.create_logger(logger=self.log, loglevel='DEBUG') + + if not any( + isinstance(h, MockLoggingHandler) for h in self.log.handlers + ): + self._sshtunnel_log_handler = MockLoggingHandler(level='DEBUG') + self.log.addHandler(self._sshtunnel_log_handler) + else: + self._sshtunnel_log_handler = next( + h + for h in self.log.handlers + if isinstance(h, MockLoggingHandler) + ) + + self.sshtunnel_log_messages = self._sshtunnel_log_handler.messages # set verbose format for logging - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' - for handler in cls.log.handlers: + _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s' # noqa: E501 line-too-long + for handler in self.log.handlers: handler.setFormatter(logging.Formatter(_fmt)) - def setUp(self): - super(SSHClientTest, self).setUp() self.log.debug('*' * 80) - self.log.info('setUp for: {0}()'.format(self._testMethodName.upper())) + self.log.info('setUp for: {0}'.format(request.node.name.upper())) + self.ssockl, self.saddr, self.sport = self.make_socket() self.esockl, self.eaddr, self.eport = self.make_socket() - self.log.info("Socket for ssh-server: {0}:{1}" - .format(self.saddr, self.sport)) - self.log.info("Socket for echo-server: {0}:{1}" - .format(self.eaddr, self.eport)) - self.ssh_event = threading.Event() + self.log.info( + 'Socket for ssh-server: {0}:{1}'.format(self.saddr, self.sport) + ) + self.log.info( + 'Socket for echo-server: {0}:{1}'.format(self.eaddr, self.eport) + ) + + self.ssh_event = threading.Event() self.running_threads = [] self.threads = {} - self.is_server_working = False self._sshtunnel_log_handler.reset() - def tearDown(self): - self.log.info('tearDown for: {0}()' - .format(self._testMethodName.upper())) + yield + + self.log.info('tearDown for: {0}'.format(request.node.name.upper())) self.stop_echo_and_ssh_server() - for thread in self.running_threads: - x = self.threads[thread] - self.log.info('thread {0} ({1})' - .format(thread, - 'alive' if x.is_alive() else 'defunct')) - - while self.running_threads: - for thread in self.running_threads: - x = self.threads[thread] - self.wait_for_thread(self.threads[thread], - who='tearDown') + + for thread_name in list(self.running_threads): + x = self.threads.get(thread_name) + if x: + self.log.info( + 'thread {0} ({1})'.format( + thread_name, 'alive' if x.is_alive() else 'defunct' + ) + ) + self.wait_for_thread(x, who='tearDown') if not x.is_alive(): - self.log.info('thread {0} now stopped'.format(thread)) + self.log.info('thread {0} now stopped'.format(thread_name)) for attr in ['server', 'tc', 'ts', 'socks', 'ssockl', 'esockl']: - if hasattr(self, attr): - self.log.info('tearDown() {0}'.format(attr)) - getattr(self, attr).close() + val = getattr(self, attr, None) + if val and hasattr(val, 'close'): + self.log.info('tearDown() closing {0}'.format(attr)) + try: + val.close() + except (socket.error, OSError) as e: + self.log.debug('Error closing {0}: {1}'.format(attr, e)) def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): if thread.is_alive(): - self.log.debug('{0}waiting for {1} to end...' - .format('{0} '.format(who) if who else '', - thread.name)) + self.log.debug( + '{0}waiting for {1} to end...'.format( + '{0} '.format(who) if who else '', thread.name + ) + ) thread.join(timeout) + def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): + self.log.debug('forward-server Start') + self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport + info = "" + schan = None + echo = None + try: + schan = self.ts.accept(timeout=timeout) + info = 'forward-server schan <> echo' + self.log.info(info + ' accept()') + echo = socket.create_connection((self.eaddr, self.eport)) + while self.is_server_working: + inputs = [ + obj for obj in [schan, echo] if ( + obj is not None and hasattr(obj, 'fileno') + ) + ] + if len(inputs) < 2: + continue + rqst, _, _ = select.select(inputs, [], [], timeout) + if schan in rqst: + data = schan.recv(1024) + self.log.debug('{0} -->: {1}'.format(info, repr(data))) + echo.send(data) + if len(data) == 0: + break + if echo in rqst: + data = echo.recv(1024) + self.log.debug('{0} <--: {1}'.format(info, repr(data))) + schan.send(data) + if len(data) == 0: + break + self.log.info('<<< forward-server received STOP signal') + except socket.error: + self.log.critical('{0} sending RST'.format(info)) + finally: + if schan: + self.log.debug('{0} closing connection...'.format(info)) + schan.close() + echo.close() + self.log.debug('{0} connection closed.'.format(info)) + + def _run_ssh_server(self): + self.log.info('ssh-server Start') + try: + self.socks, addr = self.ssockl.accept() + except socket.timeout: + self.log.error('ssh-server connection timed out!') + self.running_threads.remove('ssh-server') + return + self.ts = paramiko.Transport(self.socks) + host_key = paramiko.RSAKey.from_private_key_file( + get_test_data_path(PKEY_FILE) + ) + self.ts.add_server_key(host_key) + server = NullServer(allowed_keys=FINGERPRINTS.keys(), log=self.log) + t = threading.Thread(target=self._do_forwarding, name='forward-server') + t.daemon = DAEMON_THREADS + self.running_threads.append(t.name) + self.threads[t.name] = t + t.start() + self.ts.start_server(self.ssh_event, server) + self.wait_for_thread(t, who='ssh-server') + self.log.info('ssh-server shutting down') + self.running_threads.remove('ssh-server') + def start_echo_and_ssh_server(self): self.is_server_working = True self.start_echo_server() - t = threading.Thread(target=self._run_ssh_server, - name='ssh-server') + t = threading.Thread(target=self._run_ssh_server, name='ssh-server') t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t @@ -281,11 +396,10 @@ def stop_echo_and_ssh_server(self): def _check_server_auth(self): # Check if authentication to server was successfulZ self.ssh_event.wait(sshtunnel.SSH_TIMEOUT) # wait for transport - self.assertTrue(self.ssh_event.is_set()) - self.assertTrue(self.ts.is_active()) - self.assertEqual(self.ts.get_username(), - SSH_USERNAME) - self.assertTrue(self.ts.is_authenticated()) + assert self.ssh_event.is_set() + assert self.ts.is_active() + assert self.ts.get_username() == SSH_USERNAME + assert self.ts.is_authenticated() @contextmanager def _test_server(self, *args, **kwargs): @@ -296,59 +410,21 @@ def _test_server(self, *args, **kwargs): yield server server._stop_transport() - def start_echo_server(self): - t = threading.Thread(target=self._run_echo_server, - name='echo-server') - t.daemon = DAEMON_THREADS - self.running_threads.append(t.name) - self.threads[t.name] = t - t.start() - - def _run_ssh_server(self): - self.log.info('ssh-server Start') - try: - self.socks, addr = self.ssockl.accept() - except socket.timeout: - self.log.error('ssh-server connection timed out!') - self.running_threads.remove('ssh-server') - return - self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file( - get_test_data_path(PKEY_FILE) - ) - self.ts.add_server_key(host_key) - server = NullServer(allowed_keys=FINGERPRINTS.keys(), - log=self.log) - t = threading.Thread(target=self._do_forwarding, - name='forward-server') - t.daemon = DAEMON_THREADS - self.running_threads.append(t.name) - self.threads[t.name] = t - t.start() - self.ts.start_server(self.ssh_event, server) - self.wait_for_thread(t, - timeout=None, - who='ssh-server') - self.log.info('ssh-server shutting down') - self.running_threads.remove('ssh-server') - def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.info('echo-server Started') self.ssh_event.wait(timeout) # wait for transport socks = [self.esockl] try: while self.is_server_working: - inputready, _, _ = select.select(socks, - [], - [], - timeout) + inputready, _, _ = select.select(socks, [], [], timeout) for s in inputready: if s == self.esockl: # handle the server socket try: client, address = self.esockl.accept() - self.log.info('echo-server accept() {0}' - .format(address)) + self.log.info( + 'echo-server accept() {0}'.format(address) + ) except OSError: self.log.info('echo-server accept() OSError') break @@ -357,8 +433,9 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): # handle all other sockets try: data = s.recv(1000) - self.log.info('echo-server echoing {0}' - .format(data)) + self.log.info( + 'echo-server echoing {0}'.format(data) + ) s.send(data) except OSError: self.log.warning('echo-server OSError') @@ -373,54 +450,22 @@ def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): self.is_server_working = False if 'forward-server' in self.threads: t = self.threads['forward-server'] - self.wait_for_thread(t, timeout=None, who='echo-server') + self.wait_for_thread(t, who='echo-server') self.running_threads.remove('forward-server') for s in socks: s.close() self.log.info('echo-server shutting down') self.running_threads.remove('echo-server') - def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): - self.log.debug('forward-server Start') - self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport - try: - schan = self.ts.accept(timeout=timeout) - info = "forward-server schan <> echo" - self.log.info(info + " accept()") - echo = socket.create_connection( - (self.eaddr, self.eport) - ) - while self.is_server_working: - rqst, _, _ = select.select([schan, echo], - [], - [], - timeout) - if schan in rqst: - data = schan.recv(1024) - self.log.debug('{0} -->: {1}'.format(info, repr(data))) - echo.send(data) - if len(data) == 0: - break - if echo in rqst: - data = echo.recv(1024) - self.log.debug('{0} <--: {1}'.format(info, repr(data))) - schan.send(data) - if len(data) == 0: - break - self.log.info('<<< forward-server received STOP signal') - except socket.error: - self.log.critical('{0} sending RST'.format(info)) - # except Exception as e: - # # we reach this point usually when schan is None (paramiko bug?) - # self.log.critical(repr(e)) - finally: - if schan: - self.log.debug('{0} closing connection...'.format(info)) - schan.close() - echo.close() - self.log.debug('{0} connection closed.'.format(info)) + def start_echo_server(self): + t = threading.Thread(target=self._run_echo_server, name='echo-server') + t.daemon = DAEMON_THREADS + self.running_threads.append(t.name) + self.threads[t.name] = t + t.start() - def randomize_eport(self): + @staticmethod + def randomize_eport(): return random.randint(49152, 65535) def test_echo_server(self): @@ -435,16 +480,19 @@ def test_echo_server(self): local_bind_addr = ('127.0.0.1', server.local_bind_port) self.log.info('_test_server(): try connect!') s = socket.create_connection(local_bind_addr) - self.log.info('_test_server(): connected from {0}! try send!' - .format(s.getsockname())) + self.log.info( + '_test_server(): connected from {0}! try send!'.format( + s.getsockname() + ) + ) s.send(message) self.log.info('_test_server(): sent!') - z = (s.recv(1000)) - self.assertEqual(z, message) + z = s.recv(1000) + assert z == message s.close() def test_connect_by_username_password(self): - """ Test connecting using username/password as authentication """ + """Test connecting using username/password as authentication""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -455,7 +503,7 @@ def test_connect_by_username_password(self): pass # no exceptions are raised def test_connect_by_rsa_key_file(self): - """ Test connecting using a RSA key file """ + """Test connecting using a RSA key file""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -466,7 +514,7 @@ def test_connect_by_rsa_key_file(self): pass # no exceptions are raised def test_connect_by_paramiko_key(self): - """ Test connecting when ssh_private_key is a paramiko.RSAKey """ + """Test connecting when ssh_private_key is a paramiko.RSAKey""" ssh_key = paramiko.RSAKey.from_private_key_file( get_test_data_path(PKEY_FILE) ) @@ -480,7 +528,7 @@ def test_connect_by_paramiko_key(self): pass def test_open_tunnel(self): - """ Test wrapper method mainly used from CLI """ + """Test wrapper method mainly used from CLI""" server = sshtunnel.open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -491,22 +539,38 @@ def test_open_tunnel(self): allow_agent=False, host_pkey_directories=[], ) - self.assertEqual(server.ssh_host, self.saddr) - self.assertEqual(server.ssh_port, self.sport) - self.assertEqual(server.ssh_username, SSH_USERNAME) - self.assertEqual(server.ssh_password, SSH_PASSWORD) - self.assertEqual(server.logger, self.log) + assert server.ssh_host == self.saddr + assert server.ssh_port == self.sport + assert server.ssh_username == SSH_USERNAME + assert server.ssh_password == SSH_PASSWORD + assert server.logger == self.log self.start_echo_and_ssh_server() server.start() self._check_server_auth() server.stop() + def test_open_tunnel_block_on_close_deprecation(self): + """Ensure block_on_close keyword argument posts deprecation warning.""" + with pytest.warns( + DeprecationWarning, + match=re.escape( + "You should use either .stop() or .stop(force=True)" + ), + ): + sshtunnel.open_tunnel( + (self.saddr, self.sport), + ssh_username=SSH_USERNAME, + ssh_password=SSH_PASSWORD, + remote_bind_address=(self.eaddr, self.eport), + block_on_close=True, + ) + def test_sshaddress_and_sshaddressorhost_mutually_exclusive(self): """ Test that deprecate argument ssh_address cannot be used together with ssh_address_or_host """ - with self.assertRaises(ValueError): + with pytest.warns(DeprecationWarning), pytest.raises(ValueError): open_tunnel( ssh_address_or_host=(self.saddr, self.sport), ssh_address=(self.saddr, self.sport), @@ -520,7 +584,7 @@ def test_sshhost_and_sshaddressorhost_mutually_exclusive(self): Test that deprecate argument ssh_host cannot be used together with ssh_address_or_host """ - with self.assertRaises(ValueError): + with pytest.warns(DeprecationWarning), pytest.raises(ValueError): open_tunnel( ssh_address_or_host=(self.saddr, self.sport), ssh_host=(self.saddr, self.sport), @@ -540,17 +604,17 @@ def test_sshaddressorhost_may_not_be_a_tuple(self): ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), ) - self.assertEqual(server.ssh_port, 22) + assert server.ssh_port == 22 def test_unknown_argument_raises_exception(self): """Test that an exception is raised when setting an invalid argument""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( self.saddr, ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - i_do_not_exist=0 + i_do_not_exist=0, ) def test_more_local_than_remote_bind_sizes_raises_exception(self): @@ -558,14 +622,16 @@ def test_more_local_than_remote_bind_sizes_raises_exception(self): Test that when the number of local_bind_addresses exceed number of remote_bind_addresses, an exception is raised """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( self.saddr, ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_addresses=[ + ('127.0.0.1', self.eport), + ('127.0.0.1', self.randomize_eport()), + ], ) def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): @@ -573,15 +639,17 @@ def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): Test that arguments local_bind_address and local_bind_addresses cannot be used together """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('127.0.0.1', self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_addresses=[ + ('127.0.0.1', self.eport), + ('127.0.0.1', self.randomize_eport()), + ], ) def test_localbindaddress_host_is_optional(self): @@ -594,9 +662,9 @@ def test_localbindaddress_host_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('', self.randomize_eport()) + local_bind_address=('', self.randomize_eport()), ) as server: - self.assertEqual(server.local_bind_host, '0.0.0.0') + assert server.local_bind_host == '0.0.0.0' def test_localbindaddress_port_is_optional(self): """ @@ -608,23 +676,25 @@ def test_localbindaddress_port_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('127.0.0.1', ) + local_bind_address=('127.0.0.1',), ) as server: - self.assertIsInstance(server.local_bind_port, int) + assert isinstance(server.local_bind_port, int) def test_remotebindaddress_and_remotebindaddresses_are_exclusive(self): """ Test that arguments remote_bind_address and remote_bind_addresses cannot be used together """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - remote_bind_addresses=[(self.eaddr, self.eport), - (self.eaddr, self.randomize_eport())] + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.eaddr, self.randomize_eport()), + ], ) def test_no_remote_bind_address_raises_exception(self): @@ -632,14 +702,12 @@ def test_no_remote_bind_address_raises_exception(self): When no remote_bind_address or remote_bind_addresses are specified, a ValueError exception should be raised """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_reading_from_a_bad_sshconfigfile_does_not_raise_error(self): """ Test that when a bad ssh_config file is found, a warning is shown @@ -654,75 +722,65 @@ def test_reading_from_a_bad_sshconfigfile_does_not_raise_error(self): remote_bind_address=(self.eaddr, self.eport), local_bind_address=('127.0.0.1', self.randomize_eport()), logger=self.log, - ssh_config_file=ssh_config_file + ssh_config_file=ssh_config_file, ) logged_message = 'Could not read SSH configuration file: {0}'.format( ssh_config_file ) - self.assertIn(logged_message, self.sshtunnel_log_messages['warning']) + assert logged_message in self.sshtunnel_log_messages['warning'] def test_not_setting_password_or_pkey_raises_error(self): """ Test that when a no authentication method is specified, an exception is raised """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, remote_bind_address=(self.eaddr, self.eport), - ssh_config_file=None + ssh_config_file=None, ) - @unittest.skipIf(os.name == 'nt', - reason='Need to fix test on Windows') - def test_deprecate_warnings_are_shown(self): - """Test that when using deprecate arguments a warning is logged""" - warnings.simplefilter('always') # don't ignore DeprecationWarnings - - with warnings.catch_warnings(record=True) as w: - for deprecated_arg in ['ssh_address', 'ssh_host']: - _kwargs = { - deprecated_arg: (self.saddr, self.sport), - 'ssh_username': SSH_USERNAME, - 'ssh_password': SSH_PASSWORD, - 'remote_bind_address': (self.eaddr, self.eport), - } - open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) - self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) - self.assertEqual(logged_message, str(w[-1].message)) - - # other deprecated arguments - with warnings.catch_warnings(record=True) as w: - for deprecated_arg in [ - 'raise_exception_if_any_forwarder_have_a_problem', - 'ssh_private_key' - ]: - _kwargs = { - 'ssh_address_or_host': (self.saddr, self.sport), - 'ssh_username': SSH_USERNAME, - 'ssh_password': SSH_PASSWORD, - 'remote_bind_address': (self.eaddr, self.eport), - deprecated_arg: (self.saddr, self.sport), - } - open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) - self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) - self.assertEqual(logged_message, str(w[-1].message)) - - warnings.simplefilter('default') + @pytest.mark.parametrize( + 'deprecated_arg', + [ + 'ssh_address', + 'ssh_host', + 'raise_exception_if_any_forwarder_have_a_problem', + 'ssh_private_key', + ] + ) + def test_deprecation_warnings_are_shown(self, deprecated_arg): + """ + Deprecated arguments should log the correct DeprecationWarning. + """ + + replacement = sshtunnel._DEPRECATIONS[deprecated_arg] + expected_msg = ( + "'{0}' is DEPRECATED " + "use '{1}' instead" + ).format(deprecated_arg, replacement) + + _kwargs = { + 'ssh_username': SSH_USERNAME, + 'ssh_password': SSH_PASSWORD, + 'remote_bind_address': (self.eaddr, self.eport), + deprecated_arg: (self.saddr, self.sport), + } + + if deprecated_arg not in ('ssh_address', 'ssh_host'): + _kwargs['ssh_address_or_host'] = (self.saddr, self.sport) + + with pytest.warns(DeprecationWarning, match=expected_msg): + open_tunnel(**_kwargs) def test_gateway_unreachable_raises_exception(self): """ BaseSSHTunnelForwarderError is raised when not able to reach the ssh gateway """ - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): with open_tunnel( (self.saddr, self.randomize_eport()), ssh_username=SSH_USERNAME, @@ -732,14 +790,12 @@ def test_gateway_unreachable_raises_exception(self): ): pass - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_gateway_ip_unresolvable_raises_exception(self): """ BaseSSHTunnelForwarderError is raised when not able to resolve the ssh gateway IP address """ - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): with open_tunnel( (SSH_USERNAME, self.sport), ssh_username=SSH_USERNAME, @@ -748,33 +804,30 @@ def test_gateway_ip_unresolvable_raises_exception(self): ssh_config_file=None, ): pass - self.assertIn( + assert ( 'Could not resolve IP address for {0}, aborting!'.format( SSH_USERNAME - ), - self.sshtunnel_log_messages['error'] + ) + in self.sshtunnel_log_messages['error'] ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_running_start_twice_logs_warning(self): """Test that when running start() twice a warning is shown""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_address=(self.eaddr, self.eport) + remote_bind_address=(self.eaddr, self.eport), ) as server: - self.assertNotIn('Already started!', - self.sshtunnel_log_messages['warning']) + assert ( + 'Already started!' + not in self.sshtunnel_log_messages['warning'] + ) server.logger.error(server.is_active) server.logger.error(server.is_alive) server.start() # 2nd start should prompt the warning - self.assertIn('Already started!', - self.sshtunnel_log_messages['warning']) + assert 'Already started!' in self.sshtunnel_log_messages['warning'] - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_stop_before_start_logs_warning(self): """ Test that running .stop() on an already stopped server logs a warning @@ -788,17 +841,17 @@ def test_stop_before_start_logs_warning(self): logger=self.log, ) server.stop() - self.assertIn('Server is not started. Please .start() first!', - self.sshtunnel_log_messages['warning']) + assert ( + 'Server is not started. Please .start() first!' + in self.sshtunnel_log_messages['warning'] + ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_wrong_auth_to_gateway_logs_error(self): """ Test that when connecting to the ssh gateway with wrong credentials, an error is logged """ - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -807,11 +860,11 @@ def test_wrong_auth_to_gateway_logs_error(self): logger=self.log, ): pass - self.assertIn('Could not open connection to gateway', - self.sshtunnel_log_messages['error']) + assert ( + 'Could not open connection to gateway' + in self.sshtunnel_log_messages['error'] + ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_missing_pkey_file_logs_warning(self): """ Test that when the private key file is missing, a warning is logged @@ -825,13 +878,16 @@ def test_missing_pkey_file_logs_warning(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ): - self.assertIn('Private key file not found: {0}'.format(bad_pkey), - self.sshtunnel_log_messages['warning']) + assert ( + 'Private key file not found: {0}'.format(bad_pkey) + in self.sshtunnel_log_messages['warning'] + ) def test_connect_via_proxy(self): - """ Test connecting using a ProxyCommand """ - proxycmd = paramiko.proxy.ProxyCommand('ssh proxy -W {0}:{1}' - .format(self.saddr, self.sport)) + """Test connecting using a ProxyCommand""" + proxycmd = paramiko.proxy.ProxyCommand( + 'ssh proxy -W {0}:{1}'.format(self.saddr, self.sport) + ) server = open_tunnel( self.saddr, ssh_username=SSH_USERNAME, @@ -841,12 +897,10 @@ def test_connect_via_proxy(self): ssh_proxy_enabled=True, logger=self.log, ) - self.assertEqual(server.ssh_proxy.cmd[1], 'proxy') + assert server.ssh_proxy.cmd[1] == 'proxy' - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_can_skip_loading_sshconfig(self): - """ Test that we can skip loading the ~/.ssh/config file """ + """Test that we can skip loading the ~/.ssh/config file""" server = open_tunnel( (self.saddr, self.sport), ssh_password=SSH_PASSWORD, @@ -854,12 +908,14 @@ def test_can_skip_loading_sshconfig(self): ssh_config_file=None, logger=self.log, ) - self.assertEqual(server.ssh_username, getpass.getuser()) - self.assertIn('Skipping loading of ssh configuration file', - self.sshtunnel_log_messages['info']) + assert server.ssh_username == getpass.getuser() + assert ( + 'Skipping loading of ssh configuration file' + in self.sshtunnel_log_messages['info'] + ) def test_local_bind_port(self): - """ Test local_bind_port property """ + """Test local_bind_port property""" s = socket.socket() s.bind(('localhost', 0)) addr, port = s.getsockname() @@ -872,11 +928,11 @@ def test_local_bind_port(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_port, int) - self.assertEqual(server.local_bind_port, port) + assert isinstance(server.local_bind_port, int) + assert server.local_bind_port == port def test_local_bind_host(self): - """ Test local_bind_host property """ + """Test local_bind_host property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -885,11 +941,11 @@ def test_local_bind_host(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_host, str) - self.assertEqual(server.local_bind_host, self.saddr) + assert isinstance(server.local_bind_host, str) + assert server.local_bind_host == self.saddr def test_local_bind_address(self): - """ Test local_bind_address property """ + """Test local_bind_address property""" s = socket.socket() s.bind(('localhost', 0)) addr, port = s.getsockname() @@ -902,21 +958,23 @@ def test_local_bind_address(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_address, tuple) - self.assertTupleEqual(server.local_bind_address, (addr, port)) + assert isinstance(server.local_bind_address, tuple) + assert server.local_bind_address == (addr, port) def test_local_bind_ports(self): - """ Test local_bind_ports property """ + """Test local_bind_ports property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_ports, list) - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + assert isinstance(server.local_bind_ports, list) + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_port) # Single bind should still produce a 1 element list @@ -927,47 +985,48 @@ def test_local_bind_ports(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_ports, list) + assert isinstance(server.local_bind_ports, list) def test_local_bind_hosts(self): - """ Test local_bind_hosts property """ + """Test local_bind_hosts property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_hosts, list) - self.assertListEqual(server.local_bind_hosts, - [self.saddr] * 2) - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + assert isinstance(server.local_bind_hosts, list) + assert server.local_bind_hosts == ([self.saddr] * 2) + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_host) def test_local_bind_addresses(self): - """ Test local_bind_addresses property """ + """Test local_bind_addresses property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[ + (self.eaddr, self.eport), + (self.saddr, self.sport), + ], logger=self.log, ) as server: - self.assertIsInstance(server.local_bind_addresses, list) - self.assertListEqual(server.local_bind_addresses, - list(zip([self.saddr] * 2, - server.local_bind_ports))) - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + assert isinstance(server.local_bind_addresses, list) + assert server.local_bind_addresses == list( + zip([self.saddr] * 2, server.local_bind_ports) + ) + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_address) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_check_tunnels(self): - """ Test method checking if tunnels are up """ + """Test method checking if tunnels are up""" remote_address = (self.eaddr, self.eport) with self._test_server( (self.saddr, self.sport), @@ -977,77 +1036,97 @@ def test_check_tunnels(self): logger=self.log, skip_tunnel_checkup=False, ) as server: - self.assertIn('Tunnel to {0} is UP'.format(remote_address), - self.sshtunnel_log_messages['debug']) + assert ( + 'Tunnel to {0} is UP'.format(remote_address) + in self.sshtunnel_log_messages['debug'] + ) server.check_tunnels() - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + assert ( + 'Tunnel to {0} is DOWN'.format(remote_address) + in self.sshtunnel_log_messages['debug'] + ) # Calling local_is_up() should also return the same server.skip_tunnel_checkup = True server.local_is_up((self.saddr, self.sport)) - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + assert ( + 'Tunnel to {0} is DOWN'.format(remote_address) + in self.sshtunnel_log_messages['debug'] + ) - self.assertFalse(server.local_is_up("not a valid address")) - self.assertIn('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.', - self.sshtunnel_log_messages['warning']) + assert not server.local_is_up('not a valid address') + assert ( + 'Target must be a tuple (IP, port), where IP ' + 'is a string (i.e. "192.168.0.1") and port is ' + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.' + in self.sshtunnel_log_messages['warning'] + ) @mock.patch('sshtunnel.input_', return_value=linesep) def test_cli_main_exits_when_pressing_enter(self, input): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" self.start_echo_and_ssh_server() - sshtunnel._cli_main(args=[self.saddr, - '-U', SSH_USERNAME, - '-P', SSH_PASSWORD, - '-p', str(self.sport), - '-R', '{0}:{1}'.format(self.eaddr, - self.eport), - '-c', '', - '-n'], - host_pkey_directories=[]) + sshtunnel._cli_main( + args=[ + self.saddr, + '-U', + SSH_USERNAME, + '-P', + SSH_PASSWORD, + '-p', + str(self.sport), + '-R', + '{0}:{1}'.format(self.eaddr, self.eport), + '-c', + '', + '-n', + ], + host_pkey_directories=[], + ) self.stop_echo_and_ssh_server() - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_read_private_key_file(self): - """ Test that an encrypted private key can be opened """ + """Test that an encrypted private key can be opened""" encr_pkey = get_test_data_path(ENCRYPTED_PKEY_FILE) pkey = sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='sshtunnel', - logger=self.log + encr_pkey, pkey_password='sshtunnel', logger=self.log ) _pkey = paramiko.RSAKey.from_private_key_file( get_test_data_path(PKEY_FILE) ) - self.assertEqual(pkey, _pkey) + assert pkey == _pkey # Using a wrong password returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='bad password', - logger=self.log - )) - self.assertIn("Private key file ({0}) could not be loaded as type " - "{1} or bad password" - .format(encr_pkey, type(_pkey)), - self.sshtunnel_log_messages['debug']) + assert ( + sshtunnel.SSHTunnelForwarder.read_private_key_file( + encr_pkey, pkey_password='bad password', logger=self.log + ) + is None + ) + assert ( + 'Private key file ({0}) could not be loaded as type ' + '{1} or bad password'.format(encr_pkey, type(_pkey)) + in self.sshtunnel_log_messages['debug'] + ) # Using no password on an encrypted key returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - logger=self.log - )) - self.assertIn('Password is required for key {0}'.format(encr_pkey), - self.sshtunnel_log_messages['error']) - - @unittest.skipIf(os.name != 'posix', - reason="UNIX sockets not supported on this platform") + assert ( + sshtunnel.SSHTunnelForwarder.read_private_key_file( + encr_pkey, logger=self.log + ) + is None + ) + assert ( + 'Password is required for key {0}'.format(encr_pkey) + in self.sshtunnel_log_messages['error'] + ) + def test_unix_domains(self): - """ Test use of UNIX domain sockets in local binds """ + """Test use of UNIX domain sockets in local binds""" + + if os.name != 'posix': + pytest.skip('UNIX sockets not supported on this platform') + with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -1056,38 +1135,34 @@ def test_unix_domains(self): local_bind_address=TEST_UNIX_SOCKET, logger=self.log, ) as server: - self.assertEqual(server.local_bind_address, TEST_UNIX_SOCKET) + assert server.local_bind_address == TEST_UNIX_SOCKET - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_tracing_logging(self): """ Test that Tracing mode may be enabled for more fine-grained logs """ - logger = sshtunnel.create_logger(logger=self.log, - loglevel='TRACE') + self.log = sshtunnel.create_logger(logger=self.log, loglevel='TRACE') with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - logger=logger, + logger=self.log, ) as server: - server.logger = sshtunnel.create_logger(logger=server.logger, - loglevel='TRACE') + server.logger = sshtunnel.create_logger( + logger=server.logger, loglevel='TRACE' + ) message = get_random_string(100).encode() # Windows raises WinError 10049 if trying to connect to 0.0.0.0 s = socket.create_connection(('127.0.0.1', server.local_bind_port)) s.send(message) s.recv(100) - s.close + s.close() log = 'send to {0}'.format((self.eaddr, self.eport)) - self.assertTrue(any(log in msg for msg in - self.sshtunnel_log_messages['trace'])) + assert any(log in msg for msg in self.sshtunnel_log_messages['trace']) # set loglevel back to the original value - logger = sshtunnel.create_logger(logger=self.log, - loglevel='DEBUG') + self.log = sshtunnel.create_logger(logger=self.log, loglevel='DEBUG') def test_tunnel_bindings_contain_active_tunnels(self): """ @@ -1099,20 +1174,24 @@ def test_tunnel_bindings_contain_active_tunnels(self): (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, remote_ports[0]), - (self.eaddr, remote_ports[1])], - local_bind_addresses=[('127.0.0.1', local_ports[0]), - ('127.0.0.1', local_ports[1])], + remote_bind_addresses=[ + (self.eaddr, remote_ports[0]), + (self.eaddr, remote_ports[1]), + ], + local_bind_addresses=[ + ('127.0.0.1', local_ports[0]), + ('127.0.0.1', local_ports[1]), + ], skip_tunnel_checkup=False, ) as server: - self.assertListEqual(server.local_bind_ports, local_ports) - self.assertTupleEqual( - server.tunnel_bindings[(self.eaddr, remote_ports[0])], - ('127.0.0.1', local_ports[0]) + assert server.local_bind_ports == local_ports + assert server.tunnel_bindings[(self.eaddr, remote_ports[0])] == ( + '127.0.0.1', + local_ports[0], ) - self.assertTupleEqual( - server.tunnel_bindings[(self.eaddr, remote_ports[1])], - ('127.0.0.1', local_ports[1]) + assert server.tunnel_bindings[(self.eaddr, remote_ports[1])] == ( + '127.0.0.1', + local_ports[1], ) def check_make_ssh_forward_server_sets_daemon(self, case): @@ -1131,7 +1210,7 @@ def check_make_ssh_forward_server_sets_daemon(self, case): tunnel.daemon_forward_servers = case tunnel.start() for server in tunnel._server_list: - self.assertEqual(server.daemon_threads, case) + assert server.daemon_threads == case finally: tunnel.stop() @@ -1147,20 +1226,22 @@ def test_make_ssh_forward_server_sets_daemon_false(self): """ self.check_make_ssh_forward_server_sets_daemon(False) - def test_get_keys(self): - """ Test loading keys from the paramiko Agent """ + def test_get_keys(self, tmp_path): + """Test loading keys from the paramiko Agent""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('', self.randomize_eport()), - logger=self.log + logger=self.log, ) as server: keys = server.get_keys(logger=self.log) - self.assertIsInstance(keys, list) - self.assertFalse(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + assert isinstance(keys, list) + assert not any( + 'keys loaded from agent' in msg + for msg in self.sshtunnel_log_messages['info'] + ) with self._test_server( (self.saddr, self.sport), @@ -1168,120 +1249,133 @@ def test_get_keys(self): ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), local_bind_address=('', self.randomize_eport()), - logger=self.log + logger=self.log, ) as server: keys = server.get_keys(logger=self.log, allow_agent=True) - self.assertIsInstance(keys, list) - self.assertTrue(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + assert isinstance(keys, list) + assert any( + 'keys loaded from agent' in msg + for msg in self.sshtunnel_log_messages['info'] + ) - tmp_dir = tempfile.mkdtemp() - shutil.copy(get_test_data_path(PKEY_FILE), - os.path.join(tmp_dir, 'id_rsa')) + shutil.copy(get_test_data_path(PKEY_FILE), str(tmp_path / 'id_rsa')) keys = sshtunnel.SSHTunnelForwarder.get_keys( self.log, - host_pkey_directories=[tmp_dir, ] + host_pkey_directories=[ + str(tmp_path), + ], ) - self.assertIsInstance(keys, list) - self.assertTrue( - any('1 key(s) loaded' in msg - for msg in self.sshtunnel_log_messages['info']) + assert isinstance(keys, list) + assert any( + '1 key(s) loaded' in msg + for msg in self.sshtunnel_log_messages['info'] ) - shutil.rmtree(tmp_dir) -class AuxiliaryTest(unittest.TestCase): - """ Set of tests that do not need the mock SSH server or logger """ +class TestAuxiliary: + """Set of tests that do not need the mock SSH server or logger""" + + def _test_parser(self, parser): + assert parser['ssh_address'] == '10.10.10.10' + assert parser['ssh_username'] == getpass.getuser() + assert parser['ssh_port'] == 22 + assert parser['ssh_password'] == SSH_PASSWORD + assert parser['remote_bind_addresses'] == [ + ('10.0.0.1', 8080), + ('10.0.0.2', 8080), + ] + assert parser['local_bind_addresses'] == [('', 8081), ('', 8082)] + assert parser['ssh_host_key'] == str(SSH_DSS) + assert parser['ssh_private_key'] == __file__ + assert parser['ssh_private_key_password'] == SSH_PASSWORD + assert parser['threaded'] + assert parser['verbose'] == 3 + assert parser['ssh_proxy'] == ('10.0.0.2', 22) + assert parser['ssh_config_file'] == 'ssh_config' + assert parser['compression'] + assert not parser['allow_agent'] def test_parse_arguments_short(self): - """ Test CLI argument parsing with short parameter names """ - args = ['10.10.10.10', # ssh_address - '-U={0}'.format(getpass.getuser()), # GW username - '-p=22', # GW SSH port - '-P={0}'.format(SSH_PASSWORD), # GW password - '-R', '10.0.0.1:8080', '10.0.0.2:8080', # remote bind list - '-L', ':8081', ':8082', # local bind list - '-k={0}'.format(SSH_DSS), # hostkey - '-K={0}'.format(__file__), # pkey file - '-S={0}'.format(SSH_PASSWORD), # pkey password - '-t', # concurrent connections (threaded) - '-vvv', # triple verbosity - '-x=10.0.0.2:', # proxy address - '-c=ssh_config', # ssh configuration file - '-z', # request compression - '-n', # disable SSH agent key lookup - ] + """Test CLI argument parsing with short parameter names""" + args = [ + '10.10.10.10', # ssh_address + '-U={0}'.format(getpass.getuser()), # GW username + '-p=22', # GW SSH port + '-P={0}'.format(SSH_PASSWORD), # GW password + '-R', + '10.0.0.1:8080', + '10.0.0.2:8080', # remote bind list + '-L', + ':8081', + ':8082', # local bind list + '-k={0}'.format(SSH_DSS), # hostkey + '-K={0}'.format(__file__), # pkey file + '-S={0}'.format(SSH_PASSWORD), # pkey password + '-t', # concurrent connections (threaded) + '-vvv', # triple verbosity + '-x=10.0.0.2:', # proxy address + '-c=ssh_config', # ssh configuration file + '-z', # request compression + '-n', # disable SSH agent key lookup + ] parser = sshtunnel._parse_arguments(args) self._test_parser(parser) with capture_stdout_stderr(): # silence stderr # First argument is mandatory - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser = sshtunnel._parse_arguments(args[1:]) # -R argument is mandatory - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser = sshtunnel._parse_arguments(args[:4] + args[5:]) def test_parse_arguments_long(self): - """ Test CLI argument parsing with long parameter names """ + """Test CLI argument parsing with long parameter names""" parser = sshtunnel._parse_arguments( - ['10.10.10.10', # ssh_address - '--username={0}'.format(getpass.getuser()), # GW username - '--server_port=22', # GW SSH port - '--password={0}'.format(SSH_PASSWORD), # GW password - '--remote_bind_address', '10.0.0.1:8080', '10.0.0.2:8080', - '--local_bind_address', ':8081', ':8082', # local bind list - '--ssh_host_key={0}'.format(SSH_DSS), # hostkey - '--private_key_file={0}'.format(__file__), # pkey file - '--private_key_password={0}'.format(SSH_PASSWORD), - '--threaded', # concurrent connections (threaded) - '--verbose', '--verbose', '--verbose', # triple verbosity - '--proxy', '10.0.0.2:22', # proxy address - '--config', 'ssh_config', # ssh configuration file - '--compress', # request compression - '--noagent', # disable SSH agent key lookup - ] + [ + '10.10.10.10', # ssh_address + '--username={0}'.format(getpass.getuser()), # GW username + '--server_port=22', # GW SSH port + '--password={0}'.format(SSH_PASSWORD), # GW password + '--remote_bind_address', + '10.0.0.1:8080', + '10.0.0.2:8080', + '--local_bind_address', + ':8081', + ':8082', # local bind list + '--ssh_host_key={0}'.format(SSH_DSS), # hostkey + '--private_key_file={0}'.format(__file__), # pkey file + '--private_key_password={0}'.format(SSH_PASSWORD), + '--threaded', # concurrent connections (threaded) + '--verbose', + '--verbose', + '--verbose', # triple verbosity + '--proxy', + '10.0.0.2:22', # proxy address + '--config', + 'ssh_config', # ssh configuration file + '--compress', # request compression + '--noagent', # disable SSH agent key lookup + ] ) self._test_parser(parser) - def _test_parser(self, parser): - self.assertEqual(parser['ssh_address'], '10.10.10.10') - self.assertEqual(parser['ssh_username'], getpass.getuser()) - self.assertEqual(parser['ssh_port'], 22) - self.assertEqual(parser['ssh_password'], SSH_PASSWORD) - self.assertListEqual(parser['remote_bind_addresses'], - [('10.0.0.1', 8080), ('10.0.0.2', 8080)]) - self.assertListEqual(parser['local_bind_addresses'], - [('', 8081), ('', 8082)]) - self.assertEqual(parser['ssh_host_key'], str(SSH_DSS)) - self.assertEqual(parser['ssh_private_key'], __file__) - self.assertEqual(parser['ssh_private_key_password'], SSH_PASSWORD) - self.assertTrue(parser['threaded']) - self.assertEqual(parser['verbose'], 3) - self.assertEqual(parser['ssh_proxy'], ('10.0.0.2', 22)) - self.assertEqual(parser['ssh_config_file'], 'ssh_config') - self.assertTrue(parser['compression']) - self.assertFalse(parser['allow_agent']) - def test_bindlist(self): """ Test that _bindlist enforces IP:PORT format for local and remote binds """ - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:8080'), - ('10.0.0.1', 8080)) + assert sshtunnel._bindlist('10.0.0.1:8080') == ('10.0.0.1', 8080) # Missing port in tuple is filled with port 22 - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:'), - ('10.0.0.1', 22)) - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1'), - ('10.0.0.1', 22)) - with self.assertRaises(argparse.ArgumentTypeError): + assert sshtunnel._bindlist('10.0.0.1:') == ('10.0.0.1', 22) + assert sshtunnel._bindlist('10.0.0.1') == ('10.0.0.1', 22) + with pytest.raises(argparse.ArgumentTypeError): sshtunnel._bindlist('10022:10.0.0.1:22') - with self.assertRaises(argparse.ArgumentTypeError): + with pytest.raises(argparse.ArgumentTypeError): sshtunnel._bindlist(':') def test_raise_fwd_ext(self): - """ Test that we can silence the exceptions on sshtunnel creation """ + """Test that we can silence the exceptions on sshtunnel creation""" server = open_tunnel( '10.10.10.10', ssh_username=SSH_USERNAME, @@ -1293,110 +1387,126 @@ def test_raise_fwd_ext(self): server._raise(sshtunnel.BaseSSHTunnelForwarderError, 'test') server._raise_fwd_exc = True # now exceptions are not silenced - with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): + with pytest.raises(sshtunnel.BaseSSHTunnelForwarderError): server._raise(sshtunnel.BaseSSHTunnelForwarderError, 'test') def test_show_running_version(self): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" with capture_stdout_stderr() as (out, err): - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): sshtunnel._cli_main(args=['-V']) if sys.version_info < (3, 4): version = err.getvalue().split()[-1] else: version = out.getvalue().split()[-1] - self.assertEqual(version, - sshtunnel.__version__) + assert version == sshtunnel.__version__ def test_remove_none_values(self): - """ Test removing keys from a dict where values are None """ + """Test removing keys from a dict where values are None""" test_dict = {'key1': 1, 'key2': None, 'key3': 3, 'key4': 0} sshtunnel._remove_none_values(test_dict) - self.assertDictEqual(test_dict, - {'key1': 1, 'key3': 3, 'key4': 0}) + assert test_dict == {'key1': 1, 'key3': 3, 'key4': 0} def test_read_ssh_config(self): - """ Test that we can gather host information from a config file """ - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'test', - get_test_data_path(TEST_CONFIG_FILE), + """Test that we can gather host information from a config file""" + ( + ssh_hostname, + ssh_username, + ssh_private_key, + ssh_port, + ssh_proxy, + compression, + ) = sshtunnel.SSHTunnelForwarder._read_ssh_config( + 'test', + get_test_data_path(TEST_CONFIG_FILE), ) - self.assertEqual(ssh_hostname, 'test') - self.assertEqual(ssh_username, 'test') - self.assertEqual(PKEY_FILE, ssh_private_key) - self.assertEqual(ssh_port, 22) # fallback value - self.assertListEqual(ssh_proxy.cmd[-2:], ['test:22', 'sshproxy']) - self.assertTrue(compression) + assert ssh_hostname == 'test' + assert ssh_username == 'test' + assert PKEY_FILE == ssh_private_key + assert ssh_port == 22 # fallback value + assert ssh_proxy.cmd[-2:] == ['test:22', 'sshproxy'] + assert compression # passed parameters are not overriden by config - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'other', - get_test_data_path(TEST_CONFIG_FILE), - compression=False + ( + ssh_hostname, + ssh_username, + ssh_private_key, + ssh_port, + ssh_proxy, + compression, + ) = sshtunnel.SSHTunnelForwarder._read_ssh_config( + 'other', get_test_data_path(TEST_CONFIG_FILE), compression=False ) - self.assertEqual(ssh_hostname, '10.0.0.1') - self.assertEqual(ssh_port, 222) - self.assertFalse(compression) + assert ssh_hostname == '10.0.0.1' + assert ssh_port == 222 + assert not compression def test_str(self): server = open_tunnel( 'test', - ssh_private_key=get_test_data_path(PKEY_FILE), + ssh_pkey=get_test_data_path(PKEY_FILE), remote_bind_address=('10.0.0.1', 8080), ) _str = str(server).split(linesep) - self.assertEqual(repr(server), str(server)) - self.assertIn('ssh gateway: test:22', _str) - self.assertIn('proxy: no', _str) - self.assertIn('username: {0}'.format(getpass.getuser()), _str) - self.assertIn('status: not started', _str) + assert repr(server) == str(server) + assert 'ssh gateway: test:22' in _str + assert 'proxy: no' in _str + assert 'username: {0}'.format(getpass.getuser()) in _str + assert 'status: not started' in _str def test_process_deprecations(self): - """ Test processing deprecated API attributes """ - kwargs = {'ssh_host': '10.0.0.1', - 'ssh_address': '10.0.0.1', - 'ssh_private_key': 'testrsa.key', - 'raise_exception_if_any_forwarder_have_a_problem': True} + """Test processing deprecated API attributes""" + kwargs = { + 'ssh_host': '10.0.0.1', + 'ssh_address': '10.0.0.1', + 'ssh_private_key': 'testrsa.key', + 'raise_exception_if_any_forwarder_have_a_problem': True, + } for item in kwargs: - self.assertEqual(kwargs[item], - sshtunnel.SSHTunnelForwarder._process_deprecated( - None, - item, - kwargs.copy() - )) + with pytest.warns( + DeprecationWarning, + match="'{0}' is DEPRECATED use '.+' instead".format(item) + ): + assert kwargs[ + item + ] == sshtunnel.SSHTunnelForwarder._process_deprecated( + None, item, kwargs.copy() + ) # use both deprecated and not None new attribute should raise exception for item in kwargs: - with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - item, - kwargs.copy()) + with warnings.catch_warnings( + category=DeprecationWarning + ), pytest.raises( + ValueError, match="You can't use both '.+' and '.+'" + ): + warnings.simplefilter("ignore") + sshtunnel.SSHTunnelForwarder._process_deprecated( + 'some value', item, kwargs.copy() + ) # deprecated attribute not in deprecation list should raise exception - with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - 'item', - kwargs.copy()) + with warnings.catch_warnings( + category=DeprecationWarning + ), pytest.raises( + ValueError, match="item not included in deprecations list" + ): + warnings.simplefilter("ignore") + sshtunnel.SSHTunnelForwarder._process_deprecated( + 'some value', 'item', kwargs.copy() + ) def test_check_address(self): - """ Test that an exception is raised with incorrect bind addresses """ - address_list = [('10.0.0.1', 10000), - ('10.0.0.1', 10001)] + """Test that an exception is raised with incorrect bind addresses""" + address_list: List[Union[Tuple, str]] = [ + ('10.0.0.1', 10000), ('10.0.0.1', 10001) + ] if os.name == 'posix': # UNIX sockets supported by the platform address_list.append('/tmp/unix-socket') # UNIX sockets not supported on remote addresses - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): sshtunnel.check_addresses(address_list, is_remote=True) - self.assertIsNone(sshtunnel.check_addresses(address_list)) - with self.assertRaises(ValueError): + assert sshtunnel.check_addresses(address_list) is None + with pytest.raises(ValueError): sshtunnel.check_address('this is not valid') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): sshtunnel.check_address(-1) # that's not valid either diff --git a/tox.ini b/tox.ini index 3baf2bb6..0aaa4a2f 100644 --- a/tox.ini +++ b/tox.ini @@ -6,13 +6,14 @@ deps = paramiko -r{toxinidir}/tests/requirements.txt commands = - py.test tests \ - --showlocals \ - --cov sshtunnel \ - --cov-report=term \ - --cov-report=html \ - --durations=10 \ - -n4 -W ignore::DeprecationWarning + pytest tests \ + --showlocals \ + --durations=10 \ + -n auto \ + --cov=sshtunnel \ + --cov-report=html:test_results/coverage.html \ + --cov-report=term \ + --junit-xml=test_results/report.xml [testenv:docs] changedir = docs From b62aa759b16e6ef8d0451f84da14546951dc929d Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 20:25:40 -0500 Subject: [PATCH 02/25] backwards-compatible warning catching --- tests/test_forwarder.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 23bb1e6a..57f89e83 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1475,22 +1475,18 @@ def test_process_deprecations(self): ) # use both deprecated and not None new attribute should raise exception for item in kwargs: - with warnings.catch_warnings( - category=DeprecationWarning - ), pytest.raises( + with warnings.catch_warnings(), pytest.raises( ValueError, match="You can't use both '.+' and '.+'" ): - warnings.simplefilter("ignore") + warnings.simplefilter("ignore", category=DeprecationWarning) sshtunnel.SSHTunnelForwarder._process_deprecated( 'some value', item, kwargs.copy() ) # deprecated attribute not in deprecation list should raise exception - with warnings.catch_warnings( - category=DeprecationWarning - ), pytest.raises( + with warnings.catch_warnings(), pytest.raises( ValueError, match="item not included in deprecations list" ): - warnings.simplefilter("ignore") + warnings.simplefilter("ignore", category=DeprecationWarning) sshtunnel.SSHTunnelForwarder._process_deprecated( 'some value', 'item', kwargs.copy() ) From 772e255fd68d25e647bcdd284983cf1635de4caa Mon Sep 17 00:00:00 2001 From: Ramona T Date: Fri, 6 Mar 2026 20:42:39 -0500 Subject: [PATCH 03/25] break up test_check_address and improve coverage --- tests/test_forwarder.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 57f89e83..ecb7b334 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1491,18 +1491,28 @@ def test_process_deprecations(self): 'some value', 'item', kwargs.copy() ) - def test_check_address(self): - """Test that an exception is raised with incorrect bind addresses""" + def test_check_address_incorrect_type(self): + """Test that exception is raised with incorrect bind address type""" + with pytest.raises( + ValueError, + match='ADDRESS is not a tuple, string, or character buffer' + ): + sshtunnel.check_address(-1) + + @pytest.mark.skipif(os.name != 'posix', reason="UNIX sockets not supported by the platform") + def test_check_address_string(self): + """Test remote unix domain socket exception and invalid string exception""" address_list: List[Union[Tuple, str]] = [ - ('10.0.0.1', 10000), ('10.0.0.1', 10001) + ('10.0.0.1', 10000), ('10.0.0.1', 10001), '/tmp/unix-socket' ] - if os.name == 'posix': # UNIX sockets supported by the platform - address_list.append('/tmp/unix-socket') - # UNIX sockets not supported on remote addresses - with pytest.raises(AssertionError): - sshtunnel.check_addresses(address_list, is_remote=True) assert sshtunnel.check_addresses(address_list) is None - with pytest.raises(ValueError): + + # UNIX sockets not supported on remote addresses + with pytest.raises(AssertionError): + sshtunnel.check_addresses(address_list, is_remote=True) + + with pytest.raises( + ValueError, + match='ADDRESS not a valid socket domain socket' + ): sshtunnel.check_address('this is not valid') - with pytest.raises(ValueError): - sshtunnel.check_address(-1) # that's not valid either From 8054679fc54933050f83ec2a7f5751217ab80199 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 09:06:19 -0500 Subject: [PATCH 04/25] use verbose pipenv commands in circleci --- .circleci/config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 144d0883..6031678a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,13 +29,13 @@ jobs: keys: - sshtunnel-py<< parameters.version >>-{{ checksum "sshtunnel.py" }}-{{ checksum "tests/requirements.txt" }}-0 - run: &install - name: Install sshtunnel and build&test dependencies + name: --verbosesshtunnel and build&test dependencies command: | python --version pipenv --version pip --version - pipenv install -e . - pipenv install --dev -r tests/requirements.txt + pipenv install --verbose -e . + pipenv install --verbose --dev -r tests/requirements.txt cat Pipfile.lock environment: - PIPENV_VENV_IN_PROJECT: 1 @@ -73,7 +73,7 @@ jobs: - save_cache: *save_cache - run: name: Installing documentation dependencies - command: pipenv install --dev -r docs/requirements.txt + command: pipenv install --verbose --dev -r docs/requirements.txt - run: name: Build documentation command: pipenv run sphinx-build -WavE -b html docs _build/html @@ -91,7 +91,7 @@ jobs: - save_cache: *save_cache - run: name: Installing syntax checks dependencies - command: pipenv install --dev -r tests/requirements-syntax.txt + command: pipenv install --verbose --dev -r tests/requirements-syntax.txt - run: name: checking MANIFEST.in command: pipenv run check-manifest --ignore tox.ini,tests*,*.yml From ade21be8d7b4d77fd72f8cf2fc035c6613611bef Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 10:22:41 -0500 Subject: [PATCH 05/25] remove type syntax from test_check_address_string --- tests/test_forwarder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index ecb7b334..4e03191b 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1502,7 +1502,7 @@ def test_check_address_incorrect_type(self): @pytest.mark.skipif(os.name != 'posix', reason="UNIX sockets not supported by the platform") def test_check_address_string(self): """Test remote unix domain socket exception and invalid string exception""" - address_list: List[Union[Tuple, str]] = [ + address_list = [ ('10.0.0.1', 10000), ('10.0.0.1', 10001), '/tmp/unix-socket' ] assert sshtunnel.check_addresses(address_list) is None From b119be99430f1f8e5de8d6a5579f5410b086a0b2 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:36:32 -0500 Subject: [PATCH 06/25] reformat tests --- pyproject.toml | 2 +- tests/test_forwarder.py | 40 +++++++++++++++++++++------------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b0471b7f..864b334a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta:__legacy__" \ No newline at end of file +build-backend = "setuptools.build_meta:__legacy__" diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 4e03191b..9248526b 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -317,7 +317,7 @@ def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.debug('forward-server Start') self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport - info = "" + info = '' schan = None echo = None try: @@ -327,9 +327,9 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): echo = socket.create_connection((self.eaddr, self.eport)) while self.is_server_working: inputs = [ - obj for obj in [schan, echo] if ( - obj is not None and hasattr(obj, 'fileno') - ) + obj + for obj in [schan, echo] + if (obj is not None and hasattr(obj, 'fileno')) ] if len(inputs) < 2: continue @@ -554,7 +554,7 @@ def test_open_tunnel_block_on_close_deprecation(self): with pytest.warns( DeprecationWarning, match=re.escape( - "You should use either .stop() or .stop(force=True)" + 'You should use either .stop() or .stop(force=True)' ), ): sshtunnel.open_tunnel( @@ -749,7 +749,7 @@ def test_not_setting_password_or_pkey_raises_error(self): 'ssh_host', 'raise_exception_if_any_forwarder_have_a_problem', 'ssh_private_key', - ] + ], ) def test_deprecation_warnings_are_shown(self, deprecated_arg): """ @@ -757,10 +757,9 @@ def test_deprecation_warnings_are_shown(self, deprecated_arg): """ replacement = sshtunnel._DEPRECATIONS[deprecated_arg] - expected_msg = ( - "'{0}' is DEPRECATED " - "use '{1}' instead" - ).format(deprecated_arg, replacement) + expected_msg = "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_arg, replacement + ) _kwargs = { 'ssh_username': SSH_USERNAME, @@ -1466,7 +1465,7 @@ def test_process_deprecations(self): for item in kwargs: with pytest.warns( DeprecationWarning, - match="'{0}' is DEPRECATED use '.+' instead".format(item) + match="'{0}' is DEPRECATED use '.+' instead".format(item), ): assert kwargs[ item @@ -1478,15 +1477,15 @@ def test_process_deprecations(self): with warnings.catch_warnings(), pytest.raises( ValueError, match="You can't use both '.+' and '.+'" ): - warnings.simplefilter("ignore", category=DeprecationWarning) + warnings.simplefilter('ignore', category=DeprecationWarning) sshtunnel.SSHTunnelForwarder._process_deprecated( 'some value', item, kwargs.copy() ) # deprecated attribute not in deprecation list should raise exception with warnings.catch_warnings(), pytest.raises( - ValueError, match="item not included in deprecations list" + ValueError, match='item not included in deprecations list' ): - warnings.simplefilter("ignore", category=DeprecationWarning) + warnings.simplefilter('ignore', category=DeprecationWarning) sshtunnel.SSHTunnelForwarder._process_deprecated( 'some value', 'item', kwargs.copy() ) @@ -1495,15 +1494,19 @@ def test_check_address_incorrect_type(self): """Test that exception is raised with incorrect bind address type""" with pytest.raises( ValueError, - match='ADDRESS is not a tuple, string, or character buffer' + match='ADDRESS is not a tuple, string, or character buffer', ): sshtunnel.check_address(-1) - @pytest.mark.skipif(os.name != 'posix', reason="UNIX sockets not supported by the platform") + @pytest.mark.skipif( + os.name != 'posix', reason='UNIX sockets not supported by the platform' + ) def test_check_address_string(self): """Test remote unix domain socket exception and invalid string exception""" address_list = [ - ('10.0.0.1', 10000), ('10.0.0.1', 10001), '/tmp/unix-socket' + ('10.0.0.1', 10000), + ('10.0.0.1', 10001), + '/tmp/unix-socket', ] assert sshtunnel.check_addresses(address_list) is None @@ -1512,7 +1515,6 @@ def test_check_address_string(self): sshtunnel.check_addresses(address_list, is_remote=True) with pytest.raises( - ValueError, - match='ADDRESS not a valid socket domain socket' + ValueError, match='ADDRESS not a valid socket domain socket' ): sshtunnel.check_address('this is not valid') From 1146407018021582869c5344f265a509019fdf69 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:48:55 -0500 Subject: [PATCH 07/25] get_test_data_path should return str --- tests/test_forwarder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 9248526b..8c7a4088 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -52,7 +52,7 @@ def get_random_string(length=12): def get_test_data_path(x): - return path.join(path.abspath(path.dirname(__file__)), x) + return str(path.join(path.abspath(path.dirname(__file__)), x)) @contextmanager From 597cac2df0825f336638eed40c85deb9c2aeae2f Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:49:17 -0500 Subject: [PATCH 08/25] ensure out is declared in capture_stdout_stderr --- tests/test_forwarder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 8c7a4088..f46acda9 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -58,8 +58,8 @@ def get_test_data_path(x): @contextmanager def capture_stdout_stderr(): (old_out, old_err) = (sys.stdout, sys.stderr) + out = [StringIO(), StringIO()] try: - out = [StringIO(), StringIO()] (sys.stdout, sys.stderr) = out yield out finally: From ff6ea61c4e397ba19b5ed6575336bb8e20f9d3e6 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:49:28 -0500 Subject: [PATCH 09/25] mark _test_parser as staticmethod --- tests/test_forwarder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index f46acda9..a0901818 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1275,7 +1275,8 @@ def test_get_keys(self, tmp_path): class TestAuxiliary: """Set of tests that do not need the mock SSH server or logger""" - def _test_parser(self, parser): + @staticmethod + def _test_parser(parser): assert parser['ssh_address'] == '10.10.10.10' assert parser['ssh_username'] == getpass.getuser() assert parser['ssh_port'] == 22 From 187a2da52a794d196da11287b98a110454f22655 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:50:24 -0500 Subject: [PATCH 10/25] remove unused imports --- tests/test_forwarder.py | 75 +++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index a0901818..5e1e781d 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -15,7 +15,6 @@ from contextlib import contextmanager from functools import partial from os import linesep, path -from typing import List, Tuple, Union import mock import paramiko @@ -315,46 +314,64 @@ def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): thread.join(timeout) def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): - self.log.debug('forward-server Start') - self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport - info = '' schan = None echo = None + info = 'forward-server schan <> echo' + + self.log.debug('forward-server Start') + # wait for SSH server's transport + self.ssh_event.wait(THREADS_TIMEOUT) + try: schan = self.ts.accept(timeout=timeout) - info = 'forward-server schan <> echo' - self.log.info(info + ' accept()') - echo = socket.create_connection((self.eaddr, self.eport)) + if schan is None: + self.log.error( + '%s: Failed to accept SSH channel (timeout)', info + ) + return + + echo = socket.create_connection( + (self.eaddr, self.eport), timeout=timeout + ) + self.log.info('%s established', info) + while self.is_server_working: - inputs = [ - obj - for obj in [schan, echo] - if (obj is not None and hasattr(obj, 'fileno')) - ] - if len(inputs) < 2: - continue - rqst, _, _ = select.select(inputs, [], [], timeout) + # On Windows, select.select only accepts objects with a .fileno() + try: + r_list = [obj for obj in [schan, echo] if obj is not None] + if not r_list: + break + + rqst, _, _ = select.select(r_list, [], [], timeout) + except (ValueError, TypeError) as e: + self.log.error('%s: Select error: %s', info, e) + break + if schan in rqst: data = schan.recv(1024) - self.log.debug('{0} -->: {1}'.format(info, repr(data))) - echo.send(data) - if len(data) == 0: + if not data: # Connection closed break + self.log.debug('%s -->: %s', info, repr(data)) + echo.sendall(data) + if echo in rqst: data = echo.recv(1024) - self.log.debug('{0} <--: {1}'.format(info, repr(data))) - schan.send(data) - if len(data) == 0: + if not data: # Connection closed break - self.log.info('<<< forward-server received STOP signal') - except socket.error: - self.log.critical('{0} sending RST'.format(info)) + self.log.debug('%s <--: %s', info, repr(data)) + schan.sendall(data) + + except (socket.error, Exception) as e: + self.log.error('%s: Error during forwarding: %r', info, e) + finally: - if schan: - self.log.debug('{0} closing connection...'.format(info)) - schan.close() - echo.close() - self.log.debug('{0} connection closed.'.format(info)) + for obj in [schan, echo]: + if obj: + try: + obj.close() + except paramiko.SSHException: + pass + self.log.debug('%s connections closed.', info) def _run_ssh_server(self): self.log.info('ssh-server Start') From a3a1a04c3afb3adc40e7a3d37bbdc9c453c6586b Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 16:51:25 -0500 Subject: [PATCH 11/25] add paramiko to test dependencies --- tests/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/requirements.txt b/tests/requirements.txt index c1d3e7d3..98a55268 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,6 @@ coveralls mock +paramiko pytest>=4 pytest-cov pytest-xdist From cb1e7be417945e02e5e44f29d4049abb08250b95 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 21:09:04 -0500 Subject: [PATCH 12/25] test unix string on unsupported platform exception --- tests/test_forwarder.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 5e1e781d..6ee10c97 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1536,3 +1536,13 @@ def test_check_address_string(self): ValueError, match='ADDRESS not a valid socket domain socket' ): sshtunnel.check_address('this is not valid') + + @pytest.mark.skipif( + os.name == 'posix', reason='UNIX sockets must not be supported by the platform' + ) + def test_check_address_string_not_supported(self): + """Test unix domain socket exception on unsupported platform""" + with pytest.raises( + ValueError, match='Platform does not support UNIX domain sockets' + ): + sshtunnel.check_address('/tmp/unix-socket') From bb95aae84a45fbf422b2877a7c1783c3bca696a9 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 21:48:00 -0500 Subject: [PATCH 13/25] add test for OSError during get_keys --- tests/test_forwarder.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 6ee10c97..b1951713 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -15,6 +15,7 @@ from contextlib import contextmanager from functools import partial from os import linesep, path +from unittest.mock import patch import mock import paramiko @@ -1288,6 +1289,21 @@ def test_get_keys(self, tmp_path): for msg in self.sshtunnel_log_messages['info'] ) + def test_get_keys_check_error(self, tmp_path): + """Test if warning is shown if an OS error occurs while reading keys""" + (tmp_path / "id_rsa").write_text("this file exists") + + with patch('sshtunnel.SSHTunnelForwarder.read_private_key_file') as mock_read: + mock_read.side_effect = OSError() + sshtunnel.SSHTunnelForwarder.get_keys( + logger=self.log, + host_pkey_directories=[str(tmp_path)] + ) + + assert any( + 'Private key file' in msg and 'check error' in msg + for msg in self.sshtunnel_log_messages['warning'] + ) class TestAuxiliary: """Set of tests that do not need the mock SSH server or logger""" From ac49cb5c1533594e2506497f94bb9f28378db9c7 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 21:50:38 -0500 Subject: [PATCH 14/25] flake8 on tests --- tests/test_forwarder.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index b1951713..2ce3c63e 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -337,7 +337,7 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.info('%s established', info) while self.is_server_working: - # On Windows, select.select only accepts objects with a .fileno() + # select.select only accepts objects with .fileno() on win try: r_list = [obj for obj in [schan, echo] if obj is not None] if not r_list: @@ -1293,7 +1293,9 @@ def test_get_keys_check_error(self, tmp_path): """Test if warning is shown if an OS error occurs while reading keys""" (tmp_path / "id_rsa").write_text("this file exists") - with patch('sshtunnel.SSHTunnelForwarder.read_private_key_file') as mock_read: + with patch( + 'sshtunnel.SSHTunnelForwarder.read_private_key_file' + ) as mock_read: mock_read.side_effect = OSError() sshtunnel.SSHTunnelForwarder.get_keys( logger=self.log, @@ -1305,6 +1307,7 @@ def test_get_keys_check_error(self, tmp_path): for msg in self.sshtunnel_log_messages['warning'] ) + class TestAuxiliary: """Set of tests that do not need the mock SSH server or logger""" @@ -1536,7 +1539,7 @@ def test_check_address_incorrect_type(self): os.name != 'posix', reason='UNIX sockets not supported by the platform' ) def test_check_address_string(self): - """Test remote unix domain socket exception and invalid string exception""" + """Remote unix domain socket exception and invalid string exception""" address_list = [ ('10.0.0.1', 10000), ('10.0.0.1', 10001), @@ -1554,7 +1557,8 @@ def test_check_address_string(self): sshtunnel.check_address('this is not valid') @pytest.mark.skipif( - os.name == 'posix', reason='UNIX sockets must not be supported by the platform' + os.name == 'posix', + reason='UNIX sockets must not be supported by the platform', ) def test_check_address_string_not_supported(self): """Test unix domain socket exception on unsupported platform""" From f34e0f24c09b74cca817774492eb30fde11629ed Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 22:34:06 -0500 Subject: [PATCH 15/25] make test_check_address* os-agnostic --- tests/test_forwarder.py | 107 +++++++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 2ce3c63e..0121e1a0 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1291,15 +1291,14 @@ def test_get_keys(self, tmp_path): def test_get_keys_check_error(self, tmp_path): """Test if warning is shown if an OS error occurs while reading keys""" - (tmp_path / "id_rsa").write_text("this file exists") + (tmp_path / 'id_rsa').write_text('this file exists') with patch( 'sshtunnel.SSHTunnelForwarder.read_private_key_file' ) as mock_read: mock_read.side_effect = OSError() sshtunnel.SSHTunnelForwarder.get_keys( - logger=self.log, - host_pkey_directories=[str(tmp_path)] + logger=self.log, host_pkey_directories=[str(tmp_path)] ) assert any( @@ -1527,42 +1526,72 @@ def test_process_deprecations(self): 'some value', 'item', kwargs.copy() ) - def test_check_address_incorrect_type(self): - """Test that exception is raised with incorrect bind address type""" - with pytest.raises( - ValueError, - match='ADDRESS is not a tuple, string, or character buffer', - ): - sshtunnel.check_address(-1) - - @pytest.mark.skipif( - os.name != 'posix', reason='UNIX sockets not supported by the platform' - ) - def test_check_address_string(self): - """Remote unix domain socket exception and invalid string exception""" - address_list = [ - ('10.0.0.1', 10000), - ('10.0.0.1', 10001), - '/tmp/unix-socket', - ] - assert sshtunnel.check_addresses(address_list) is None - # UNIX sockets not supported on remote addresses - with pytest.raises(AssertionError): - sshtunnel.check_addresses(address_list, is_remote=True) +@pytest.mark.parametrize( + ('address', 'os_name', 'path_exists', 'expected_error', 'match'), + [ + ( + -1, + 'posix', + False, + ValueError, + 'ADDRESS is not a tuple, string, or character buffer', + ), + ( + 'not/a/path', + 'posix', + False, + ValueError, + 'ADDRESS not a valid socket domain socket', + ), + ( + '/tmp/unix.sock', + 'nt', + True, + ValueError, + 'Platform does not support UNIX domain sockets', + ), + ('/tmp/unix.sock', 'posix', True, None, None), + (('10.0.0.1', 8080), 'posix', True, None, None), + ], +) +def test_check_address_combined( + address, os_name, path_exists, expected_error, match +): + with patch('os.name', os_name), patch( + 'os.path.exists', return_value=path_exists + ), patch('os.access', return_value=path_exists): + if expected_error: + with pytest.raises(expected_error, match=match): + sshtunnel.check_address(address) + else: + # Should not raise any exception + sshtunnel.check_address(address) - with pytest.raises( - ValueError, match='ADDRESS not a valid socket domain socket' - ): - sshtunnel.check_address('this is not valid') - @pytest.mark.skipif( - os.name == 'posix', - reason='UNIX sockets must not be supported by the platform', - ) - def test_check_address_string_not_supported(self): - """Test unix domain socket exception on unsupported platform""" - with pytest.raises( - ValueError, match='Platform does not support UNIX domain sockets' - ): - sshtunnel.check_address('/tmp/unix-socket') +@pytest.mark.parametrize( + ('address_list', 'is_remote', 'expected_error', 'match'), + [ + ([('10.0.0.1', 10000), '/tmp/unix-socket'], False, None, None), + ( + [('10.0.0.1', 10000), '/tmp/unix-socket'], + True, + AssertionError, + 'UNIX domain sockets not allowed', + ), + ([('10.0.0.1', 10000), 123], False, AssertionError, None), + ], +) +def test_check_addresses_combined( + address_list, is_remote, expected_error, match +): + with ( + patch('os.name', 'posix'), + patch('os.path.exists', return_value=True), + patch('os.access', return_value=True), + ): + if expected_error: + with pytest.raises(expected_error, match=match): + sshtunnel.check_addresses(address_list, is_remote=is_remote) + else: + sshtunnel.check_addresses(address_list, is_remote=is_remote) From 737113a9d209637f2c44181e9bdaf535d9cf0b47 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Sat, 7 Mar 2026 22:53:17 -0500 Subject: [PATCH 16/25] Revert "use verbose pipenv commands in circleci" This reverts commit 8054679fc54933050f83ec2a7f5751217ab80199. --- .circleci/config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6031678a..144d0883 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,13 +29,13 @@ jobs: keys: - sshtunnel-py<< parameters.version >>-{{ checksum "sshtunnel.py" }}-{{ checksum "tests/requirements.txt" }}-0 - run: &install - name: --verbosesshtunnel and build&test dependencies + name: Install sshtunnel and build&test dependencies command: | python --version pipenv --version pip --version - pipenv install --verbose -e . - pipenv install --verbose --dev -r tests/requirements.txt + pipenv install -e . + pipenv install --dev -r tests/requirements.txt cat Pipfile.lock environment: - PIPENV_VENV_IN_PROJECT: 1 @@ -73,7 +73,7 @@ jobs: - save_cache: *save_cache - run: name: Installing documentation dependencies - command: pipenv install --verbose --dev -r docs/requirements.txt + command: pipenv install --dev -r docs/requirements.txt - run: name: Build documentation command: pipenv run sphinx-build -WavE -b html docs _build/html @@ -91,7 +91,7 @@ jobs: - save_cache: *save_cache - run: name: Installing syntax checks dependencies - command: pipenv install --verbose --dev -r tests/requirements-syntax.txt + command: pipenv install --dev -r tests/requirements-syntax.txt - run: name: checking MANIFEST.in command: pipenv run check-manifest --ignore tox.ini,tests*,*.yml From 63b15103cb70cd40f7e772d90ad9c46831eb64a6 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 22:39:58 -0400 Subject: [PATCH 17/25] backwards-compatible mock import --- tests/test_forwarder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 0121e1a0..fdd4cc3c 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -15,9 +15,12 @@ from contextlib import contextmanager from functools import partial from os import linesep, path -from unittest.mock import patch -import mock +try: + from unittest.mock import MagicMock, patch +except ImportError: + from mock import MagicMock, patch + import paramiko import pytest From 07b8e05dab89d3795ba354ca631281b3d90ee8a2 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 22:40:11 -0400 Subject: [PATCH 18/25] sort imports in run_docker__* --- e2e_tests/run_docker_e2e_db_tests.py | 14 ++++++++------ e2e_tests/run_docker_e2e_hangs_tests.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py index b9ea4df6..fea8ee2e 100644 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ b/e2e_tests/run_docker_e2e_db_tests.py @@ -1,14 +1,16 @@ +import logging +import os import select -import traceback import sys -import os -import time -from sshtunnel import SSHTunnelForwarder -import sshtunnel -import logging import threading +import time +import traceback + import paramiko +import sshtunnel +from sshtunnel import SSHTunnelForwarder + sshtunnel.DEFAULT_LOGLEVEL = 1 logging.basicConfig( format='%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s', level=1) diff --git a/e2e_tests/run_docker_e2e_hangs_tests.py b/e2e_tests/run_docker_e2e_hangs_tests.py index 0ec7449e..7abd491d 100644 --- a/e2e_tests/run_docker_e2e_hangs_tests.py +++ b/e2e_tests/run_docker_e2e_hangs_tests.py @@ -1,7 +1,7 @@ import logging -import sshtunnel import os +import sshtunnel if __name__ == '__main__': path = os.path.join(os.path.dirname(__file__), 'run_docker_e2e_db_tests.py') From 61fcaf72336ee09f91fc60a85907070f58fdbcf6 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 22:43:55 -0400 Subject: [PATCH 19/25] use old-style context manager syntax in test_check_addresses_combined --- tests/test_forwarder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index fdd4cc3c..919aa074 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1588,11 +1588,9 @@ def test_check_address_combined( def test_check_addresses_combined( address_list, is_remote, expected_error, match ): - with ( - patch('os.name', 'posix'), - patch('os.path.exists', return_value=True), - patch('os.access', return_value=True), - ): + with patch('os.name', 'posix'), \ + patch('os.path.exists', return_value=True), \ + patch('os.access', return_value=True): if expected_error: with pytest.raises(expected_error, match=match): sshtunnel.check_addresses(address_list, is_remote=is_remote) From 273d73a644648f359564c9971a30100e788f465c Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 22:49:28 -0400 Subject: [PATCH 20/25] fix import for newer python versions --- tests/test_forwarder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 919aa074..a003bd98 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -17,9 +17,10 @@ from os import linesep, path try: - from unittest.mock import MagicMock, patch + from unittest import mock + from unittest.mock import patch except ImportError: - from mock import MagicMock, patch + from mock import mock, patch import paramiko import pytest From 45ff4d4b186fb95b42439efab35f303328f2d642 Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 22:50:55 -0400 Subject: [PATCH 21/25] don't install mock if we aren't using it --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 98a55268..5479cd0f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,5 @@ coveralls -mock +mock; python_version < '3.3' paramiko pytest>=4 pytest-cov From 944f0459a437da764e03ee53939e63cc415a422b Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 23:00:24 -0400 Subject: [PATCH 22/25] more succinct capture_stdout_stderr --- tests/test_forwarder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index a003bd98..892ec3c0 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -61,15 +61,15 @@ def get_test_data_path(x): @contextmanager def capture_stdout_stderr(): - (old_out, old_err) = (sys.stdout, sys.stderr) - out = [StringIO(), StringIO()] + out, err = StringIO(), StringIO() + old_out, old_err = sys.stdout, sys.stderr try: - (sys.stdout, sys.stderr) = out - yield out + sys.stdout, sys.stderr = out, err + yield [out, err] finally: - (sys.stdout, sys.stderr) = (old_out, old_err) - out[0] = out[0].getvalue() - out[1] = out[1].getvalue() + sys.stdout, sys.stderr = old_out, old_err + out.seek(0) + err.seek(0) # Ensure that ``ssh_config_file is None`` during tests, exceptions are not From 1cd107e6a419f43178182e75f1c584019717b9fe Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 23:04:20 -0400 Subject: [PATCH 23/25] simplify MockLoggingHandler --- tests/test_forwarder.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 892ec3c0..1195c57c 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -118,30 +118,18 @@ class MockLoggingHandler(logging.Handler, object): def __init__(self, *args, **kwargs): self.messages = { - 'debug': [], - 'info': [], - 'warning': [], - 'error': [], - 'critical': [], - 'trace': [], + k: [] for k in [ + 'debug', 'info', 'warning', 'error', 'critical', 'trace' + ] } - super(MockLoggingHandler, self).__init__(*args, **kwargs) + logging.Handler.__init__(self, *args, **kwargs) def emit(self, record): - """Store a message from ``record`` in ``self.messages`` dict.""" - self.acquire() - try: - self.messages[record.levelname.lower()].append(record.getMessage()) - finally: - self.release() + self.messages[record.levelname.lower()].append(record.getMessage()) def reset(self): - self.acquire() - try: - for message_list in self.messages: - self.messages[message_list] = [] - finally: - self.release() + for k in self.messages: + self.messages[k] = [] class NullServer(paramiko.ServerInterface): From db28147cc8785fea580230f10ced401d4438fd0f Mon Sep 17 00:00:00 2001 From: Ramona T Date: Wed, 11 Mar 2026 23:16:01 -0400 Subject: [PATCH 24/25] replace tmp_path with tmpdir for 2.7 compatibility in tests --- tests/test_forwarder.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 1195c57c..ca30d17a 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1235,7 +1235,7 @@ def test_make_ssh_forward_server_sets_daemon_false(self): """ self.check_make_ssh_forward_server_sets_daemon(False) - def test_get_keys(self, tmp_path): + def test_get_keys(self, tmpdir): """Test loading keys from the paramiko Agent""" with self._test_server( (self.saddr, self.sport), @@ -1267,13 +1267,11 @@ def test_get_keys(self, tmp_path): for msg in self.sshtunnel_log_messages['info'] ) - shutil.copy(get_test_data_path(PKEY_FILE), str(tmp_path / 'id_rsa')) + shutil.copy(get_test_data_path(PKEY_FILE), str(tmpdir.join('id_rsa'))) keys = sshtunnel.SSHTunnelForwarder.get_keys( self.log, - host_pkey_directories=[ - str(tmp_path), - ], + host_pkey_directories=[str(tmpdir)], ) assert isinstance(keys, list) assert any( @@ -1281,16 +1279,16 @@ def test_get_keys(self, tmp_path): for msg in self.sshtunnel_log_messages['info'] ) - def test_get_keys_check_error(self, tmp_path): + def test_get_keys_check_error(self, tmpdir): """Test if warning is shown if an OS error occurs while reading keys""" - (tmp_path / 'id_rsa').write_text('this file exists') + tmpdir.join('id_rsa').write('this file exists') with patch( 'sshtunnel.SSHTunnelForwarder.read_private_key_file' ) as mock_read: mock_read.side_effect = OSError() sshtunnel.SSHTunnelForwarder.get_keys( - logger=self.log, host_pkey_directories=[str(tmp_path)] + logger=self.log, host_pkey_directories=[str(tmpdir)] ) assert any( From f9baf6253e5b3da2a708994e11f51da9441d836b Mon Sep 17 00:00:00 2001 From: Ramona T Date: Thu, 12 Mar 2026 13:59:13 -0400 Subject: [PATCH 25/25] sync paramiko dep from setup to test requirements --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 5479cd0f..ebfb16a7 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ coveralls mock; python_version < '3.3' -paramiko +paramiko>=2.7.2 pytest>=4 pytest-cov pytest-xdist