Skip to content

Commit 7232d85

Browse files
maxballengerMax Ballenger
andauthored
Ensuring disconnect happens after close_channel in TunnelServer (#307)
* Fix issue with TunnelServer disconnect while channel still open * Moving client disconnect() into TunnelServer * Added test for proxy reconnection * Resolves #304 Co-authored-by: Max Ballenger <max@aeva.ai>
1 parent 46e9a3d commit 7232d85

File tree

3 files changed

+56
-8
lines changed

3 files changed

+56
-8
lines changed

pssh/clients/native/single.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,18 @@ def disconnect(self):
173173
except Exception:
174174
pass
175175
self.session = None
176-
self.sock = None
177176
if isinstance(self._proxy_client, SSHClient):
178-
self._proxy_client.disconnect()
177+
# Don't disconnect proxy client here - let the TunnelServer do it at the time that
178+
# _wait_send_receive_lets ends. The cleanup_server call here triggers the TunnelServer
179+
# to stop.
180+
FORWARDER.cleanup_server(self._proxy_client)
181+
182+
# I wanted to clean up all the sockets here to avoid a ResourceWarning from unittest,
183+
# but unfortunately closing this socket here causes a segfault, not sure why yet.
184+
# self.sock.close()
185+
else:
186+
self.sock.close()
187+
self.sock = None
179188

180189
def spawn_send_keepalive(self):
181190
"""Spawns a new greenlet that sends keep alive messages every

pssh/clients/native/tunnel.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,8 @@ def _cleanup_servers_let(self):
9090

9191
def _cleanup_servers(self):
9292
for client in list(self._servers.keys()):
93-
server = self._servers[client]
9493
if client.sock is None or client.sock.closed:
95-
server.stop()
96-
del self._servers[client]
94+
self.cleanup_server(client)
9795

9896
def run(self):
9997
"""Thread runner ensures a non main hub has been created for all subsequent
@@ -118,6 +116,13 @@ def run(self):
118116
exc_info=1)
119117
self.shutdown()
120118

119+
def cleanup_server(self, client):
120+
"""The purpose of this function is for a proxied client to notify the LocalForwarder that it
121+
is shutting down and its corresponding server can also be shut down."""
122+
server = self._servers[client]
123+
server.stop()
124+
del self._servers[client]
125+
121126

122127
class TunnelServer(StreamServer):
123128
"""Local port forwarding server for tunneling connections from remote SSH server.
@@ -165,9 +170,12 @@ def _wait_send_receive_lets(self, source, dest, channel, forward_sock):
165170
try:
166171
joinall((source, dest), raise_error=True)
167172
finally:
168-
logger.debug("Closing channel and forward socket")
173+
# Forward socket does not need to be closed here; StreamServer does it in do_close
174+
logger.debug("Closing channel")
169175
self._client.close_channel(channel)
170-
forward_sock.close()
176+
177+
# Disconnect client here to make sure it happens AFTER close_channel
178+
self._client.disconnect()
171179

172180
def _read_forward_sock(self, forward_sock, channel):
173181
while True:

tests/native/test_tunnel.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import string
2424
import random
2525
import time
26+
import gc
2627

2728
from datetime import datetime
2829
from socket import timeout as socket_timeout
@@ -32,7 +33,7 @@
3233

3334
from pssh.config import HostConfig
3435
from pssh.clients.native import SSHClient, ParallelSSHClient
35-
from pssh.clients.native.tunnel import LocalForwarder, TunnelServer
36+
from pssh.clients.native.tunnel import LocalForwarder, TunnelServer, FORWARDER
3637
from pssh.exceptions import UnknownHostException, \
3738
AuthenticationException, ConnectionErrorException, SessionError, \
3839
HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \
@@ -95,6 +96,36 @@ def test_tunnel_server(self):
9596
self.assertEqual(self.port, client.port)
9697
finally:
9798
remote_server.stop()
99+
100+
# The purpose of this test is to exercise
101+
# https://github.com/ParallelSSH/parallel-ssh/issues/304
102+
def test_tunnel_server_reconn(self):
103+
remote_host = '127.0.0.8'
104+
remote_server = OpenSSHServer(listen_ip=remote_host, port=self.port)
105+
remote_server.start_server()
106+
107+
reconn_n = 20 # Number of reconnect attempts
108+
reconn_delay = 1 # Number of seconds to delay betwen reconnects
109+
try:
110+
for _ in range(reconn_n):
111+
client = SSHClient(
112+
remote_host, port=self.port, pkey=self.user_key,
113+
num_retries=1,
114+
proxy_host=self.proxy_host,
115+
proxy_pkey=self.user_key,
116+
proxy_port=self.proxy_port,
117+
)
118+
output = client.run_command(self.cmd)
119+
_stdout = list(output.stdout)
120+
self.assertListEqual(_stdout, [self.resp])
121+
self.assertEqual(remote_host, client.host)
122+
self.assertEqual(self.port, client.port)
123+
client.disconnect()
124+
FORWARDER._cleanup_servers()
125+
time.sleep(reconn_delay)
126+
gc.collect()
127+
finally:
128+
remote_server.stop()
98129

99130
def test_tunnel_server_same_port(self):
100131
remote_host = '127.0.0.7'

0 commit comments

Comments
 (0)