diff --git a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs index 7d731b9e5..16c0fce15 100644 --- a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs @@ -1,6 +1,9 @@ using System; +using System.IO; +using System.Linq; using System.Net; using System.Net.Sockets; +using System.Threading; using Renci.SshNet.Abstractions; using Renci.SshNet.Common; using Renci.SshNet.Messages.Connection; @@ -16,6 +19,10 @@ internal class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip private Socket _socket; private IForwardedPort _forwardedPort; + private bool doSocks; + private bool doSocks5; + private ManualResetEvent completionWaitHandle; + /// /// Initializes a new instance. /// @@ -69,7 +76,17 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort) _forwardedPort = forwardedPort; _forwardedPort.Closing += ForwardedPort_Closing; - // Try to connect to the socket + if (remoteEndpoint == null) + { + doSocks = true; + SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber)); + + completionWaitHandle = new ManualResetEvent(false); + completionWaitHandle.WaitOne(); + completionWaitHandle.Dispose(); + } + + // Try to connect to the socket try { _socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout); @@ -111,6 +128,11 @@ private void ForwardedPort_Closing(object sender, EventArgs eventArgs) // // if the FIN/ACK is not sent in time, the socket will be closed in Close(bool) ShutdownSocket(SocketShutdown.Send); + + if (completionWaitHandle != null) + { + completionWaitHandle.Set(); + } } /// @@ -190,6 +212,13 @@ protected override void Close() /// The data. protected override void OnData(byte[] data) { + if (doSocks) + { + var stream = new MemoryStream(data); + HandleSocks(stream); + return; + } + base.OnData(data); var socket = _socket; @@ -198,5 +227,251 @@ protected override void OnData(byte[] data) SocketAbstraction.Send(socket, data, 0, data.Length); } } + + private void HandleSocks(MemoryStream stream) + { + var version = ReadByte(stream); + switch (version) + { + case 4: + HandleSocks4(stream); + doSocks = false; + return; + case 5: + if (!doSocks5) + { + var authenticationMethodsCount = ReadByte(stream); + var authenticationMethods = new byte[authenticationMethodsCount]; + if (stream.Read(authenticationMethods, 0, authenticationMethods.Length) == 0) + { + return; + } + + if (authenticationMethods.Min() == 0) + { + // no user authentication is one of the authentication methods supported + // by the SOCKS client + SendData(new byte[] { 0x05, 0x00 }); + } + else + { + // the SOCKS client requires authentication, which we currently do not support + SendData(new byte[] { 0x05, 0xFF }); + } + doSocks5 = true; + return; + } + HandleSocks5(stream); + doSocks = false; + return; + } + throw new NotSupportedException(string.Format("SOCKS version {0} is not supported.", version)); + } + + private void HandleSocks4(MemoryStream stream) + { + var commandCode = ReadByte(stream); + if (commandCode == -1) + { + return; + } + + var portBuffer = new byte[2]; + if (stream.Read(portBuffer, 0, portBuffer.Length) == 0) + { + return; + } + + var port = (portBuffer[0] * 256 + portBuffer[1]); + + var ipBuffer = new byte[4]; + if (stream.Read(ipBuffer, 0, ipBuffer.Length) == 0) + { + return; + } + + var ipAddress = new IPAddress(ipBuffer); + + ThreadAbstraction.ExecuteThread(() => + { + var endpoint = new IPEndPoint(ipAddress, port); + + try + { + _socket = SocketAbstraction.Connect(endpoint, ConnectionInfo.Timeout); + } + catch (Exception exp) + { + // send channel open failure message + SendMessage(new ChannelOpenFailureMessage(RemoteChannelNumber, exp.ToString(), ChannelOpenFailureMessage.ConnectFailed, "en")); + completionWaitHandle.Set(); + throw; + } + + SendData(new byte[] { 0x00, 0x5a }); + SendData(portBuffer); + SendData(ipBuffer); + + var buffer = new byte[RemotePacketSize]; + SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData); + }); + } + + private void HandleSocks5(MemoryStream stream) + { + var commandCode = ReadByte(stream); + if (commandCode == -1) + { + return; + } + + var reserved = ReadByte(stream); + if (reserved == -1) + { + return; + } + + if (reserved != 0) + { + throw new ProxyException("SOCKS5: 0 is expected for reserved byte."); + } + + var addressType = ReadByte(stream); + if (addressType == -1) + { + // SOCKS client closed connection + return; + } + + var ipAddress = GetSocks5Host(addressType, stream); + if (ipAddress == null) + { + // SOCKS client closed connection + return; + } + + var portBuffer = new byte[2]; + if (stream.Read(portBuffer, 0, portBuffer.Length) == 0) + { + return; + } + + var port = (portBuffer[0] * 256 + portBuffer[1]); + + ThreadAbstraction.ExecuteThread(() => + { + var endpoint = new IPEndPoint(ipAddress, port); + + try + { + _socket = SocketAbstraction.Connect(endpoint, ConnectionInfo.Timeout); + } + catch + { + // send channel open failure message + SendData(CreateSocks5Reply(false)); + completionWaitHandle.Set(); + throw; + } + + SendData(CreateSocks5Reply(true)); + + var buffer = new byte[RemotePacketSize]; + SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData); + }); + } + + private IPAddress GetSocks5Host(int addressType, MemoryStream stream) + { + switch (addressType) + { + case 0x01: // IPv4 + { + var addressBuffer = new byte[4]; + if (stream.Read(addressBuffer, 0, 4) == 0) + { + // SOCKS client closed connection + return null; + } + + return new IPAddress(addressBuffer); + } + case 0x03: // Domain name + { + var length = ReadByte(stream); + if (length == -1) + { + // SOCKS client closed connection + return null; + } + var addressBuffer = new byte[length]; + if (stream.Read(addressBuffer, 0, addressBuffer.Length) == 0) + { + // SOCKS client closed connection + return null; + } + + var hostName = SshData.Ascii.GetString(addressBuffer, 0, addressBuffer.Length); + return DnsAbstraction.GetHostAddresses(hostName)[0]; + } + case 0x04: // IPv6 + { + var addressBuffer = new byte[16]; + if (stream.Read(addressBuffer, 0, 16) == 0) + { + return null; + } + + return new IPAddress(addressBuffer); + } + default: + throw new ProxyException(string.Format("SOCKS5: Address type '{0}' is not supported.", addressType)); + } + } + + private static byte[] CreateSocks5Reply(bool success) + { + var socksReply = new byte[ + // SOCKS version + 1 + + // Reply field + 1 + + // Reserved; fixed: 0x00 + 1 + + // Address type; fixed: 0x01 + 1 + + // IPv4 server bound address; fixed: {0x00, 0x00, 0x00, 0x00} + 4 + + // server bound port; fixed: {0x00, 0x00} + 2]; + + socksReply[0] = 0x05; + + if (success) + { + socksReply[1] = 0x00; // succeeded + } + else + { + socksReply[1] = 0x01; // general SOCKS server failure + } + + // reserved + socksReply[2] = 0x00; + + // IPv4 address type + socksReply[3] = 0x01; + + return socksReply; + } + + private int ReadByte(MemoryStream stream) + { + var buffer = new byte[1]; + if (stream.Read(buffer, 0, 1) == 0) + return -1; + + return buffer[0]; + } } -} +} \ No newline at end of file diff --git a/src/Renci.SshNet/ForwardedPortRemote.cs b/src/Renci.SshNet/ForwardedPortRemote.cs index 4b5915de3..f18e00832 100644 --- a/src/Renci.SshNet/ForwardedPortRemote.cs +++ b/src/Renci.SshNet/ForwardedPortRemote.cs @@ -129,6 +129,53 @@ public ForwardedPortRemote(string boundHost, uint boundPort, string host, uint p { } + + /// + /// Initializes a new instance of the class. + /// + /// The bound port. + /// + /// + /// + public ForwardedPortRemote(uint boundPort) + : this (string.Empty, boundPort) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The bound host. + /// The bound port. + /// + /// + /// + public ForwardedPortRemote(string boundHost, uint boundPort) + : this(DnsAbstraction.GetHostAddresses(boundHost)[0], + boundPort) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The bound host address. + /// The bound port. + /// is null. + /// is greater than . + public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort) + { + if (boundHostAddress == null) + throw new ArgumentNullException("boundHostAddress"); + + boundPort.ValidatePort("boundPort"); + + BoundHostAddress = boundHostAddress; + BoundPort = boundPort; + + _status = ForwardedPortStatus.Stopped; + } + /// /// Starts remote port forwarding. /// @@ -151,7 +198,7 @@ protected override void StartPort() // send global request to start forwarding Session.SendMessage(new TcpIpForwardGlobalRequestMessage(BoundHost, BoundPort)); - // wat for response on global request to start direct tcpip + // wait for response on global request to start direct tcpip Session.WaitOnHandle(_globalRequestResponse); if (!_requestStatus) @@ -250,7 +297,15 @@ private void Session_ChannelOpening(object sender, MessageEventArgs