Skip to content

Commit ac840ae

Browse files
authored
[Work] Refactor for generic work type (#1048)
* Add `TcpOrTlsSocket` type * isort * Update to use Fd executor * Define `HostPort` type * Fix fileno * spellfix
1 parent d7a568e commit ac840ae

File tree

22 files changed

+147
-106
lines changed

22 files changed

+147
-106
lines changed

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,10 @@
317317
(_py_class_role, 'connection.Connection'),
318318
(_py_class_role, 'EventQueue'),
319319
(_py_class_role, 'T'),
320+
(_py_class_role, 'HostPort'),
321+
(_py_class_role, 'TcpOrTlsSocket'),
320322
(_py_obj_role, 'proxy.core.work.threadless.T'),
321323
(_py_obj_role, 'proxy.core.work.work.T'),
322324
(_py_obj_role, 'proxy.core.base.tcp_server.T'),
325+
(_py_obj_role, 'proxy.core.work.fd.T'),
323326
]

proxy/common/types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import ssl
1112
import queue
13+
import socket
1214
import ipaddress
1315
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
1416

@@ -25,5 +27,6 @@
2527
Readables = Selectables
2628
Writables = Selectables
2729
Descriptors = Tuple[Readables, Writables]
28-
2930
IpAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
31+
TcpOrTlsSocket = Union[ssl.SSLSocket, socket.socket]
32+
HostPort = Tuple[str, int]

proxy/common/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from types import TracebackType
2424
from typing import Any, Dict, List, Type, Tuple, Callable, Optional
2525

26+
from .types import HostPort
2627
from .constants import (
2728
CRLF, COLON, HTTP_1_1, IS_WINDOWS, WHITESPACE, DEFAULT_TIMEOUT,
2829
DEFAULT_THREADLESS,
@@ -220,9 +221,9 @@ def wrap_socket(
220221

221222

222223
def new_socket_connection(
223-
addr: Tuple[str, int],
224+
addr: HostPort,
224225
timeout: float = DEFAULT_TIMEOUT,
225-
source_address: Optional[Tuple[str, int]] = None,
226+
source_address: Optional[HostPort] = None,
226227
) -> socket.socket:
227228
conn = None
228229
try:
@@ -252,8 +253,8 @@ def new_socket_connection(
252253
class socket_connection(contextlib.ContextDecorator):
253254
"""Same as new_socket_connection but as a context manager and decorator."""
254255

255-
def __init__(self, addr: Tuple[str, int]):
256-
self.addr: Tuple[str, int] = addr
256+
def __init__(self, addr: HostPort):
257+
self.addr: HostPort = addr
257258
self.conn: Optional[socket.socket] = None
258259
super().__init__()
259260

proxy/core/acceptor/acceptor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..work import LocalExecutor, start_threaded_work, delegate_work_to_pool
2727
from ..event import EventQueue
2828
from ...common.flag import flags
29+
from ...common.types import HostPort
2930
from ...common.logger import Logger
3031
from ...common.backports import NonBlockingQueue
3132
from ...common.constants import DEFAULT_LOCAL_EXECUTOR
@@ -104,7 +105,7 @@ def __init__(
104105
def accept(
105106
self,
106107
events: List[Tuple[selectors.SelectorKey, int]],
107-
) -> List[Tuple[socket.socket, Optional[Tuple[str, int]]]]:
108+
) -> List[Tuple[socket.socket, Optional[HostPort]]]:
108109
works = []
109110
for key, mask in events:
110111
if mask & selectors.EVENT_READ:
@@ -156,8 +157,8 @@ def run(self) -> None:
156157
self.flags.log_format,
157158
)
158159
self.selector = selectors.DefaultSelector()
159-
self._recv_and_setup_socks()
160160
try:
161+
self._recv_and_setup_socks()
161162
if self.flags.threadless and self.flags.local_executor:
162163
self._start_local()
163164
for fileno in self.socks:
@@ -209,7 +210,7 @@ def _stop_local(self) -> None:
209210
self._local_work_queue.put(False)
210211
self._lthread.join()
211212

212-
def _work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None:
213+
def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None:
213214
self._total = self._total or 0
214215
if self.flags.threadless:
215216
# Index of worker to which this work should be dispatched

proxy/core/base/tcp_server.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
1313
tcp
1414
"""
15-
import ssl
1615
import socket
1716
import logging
1817
import selectors
1918
from abc import abstractmethod
20-
from typing import Any, Union, TypeVar, Optional
19+
from typing import Any, TypeVar, Optional
2120

2221
from ...core.work import Work
2322
from ...common.flag import flags
24-
from ...common.types import Readables, Writables, SelectableEvents
23+
from ...common.types import (
24+
Readables, Writables, TcpOrTlsSocket, SelectableEvents,
25+
)
2526
from ...common.utils import wrap_socket
2627
from ...core.connection import TcpClientConnection
2728
from ...common.constants import (
@@ -208,9 +209,7 @@ def _encryption_enabled(self) -> bool:
208209
return self.flags.keyfile is not None and \
209210
self.flags.certfile is not None
210211

211-
def _optionally_wrap_socket(
212-
self, conn: socket.socket,
213-
) -> Union[ssl.SSLSocket, socket.socket]:
212+
def _optionally_wrap_socket(self, conn: socket.socket) -> TcpOrTlsSocket:
214213
"""Attempts to wrap accepted client connection using provided certificates.
215214
216215
Shutdown and closes client connection upon error.

proxy/core/connection/client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,32 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
import ssl
12-
import socket
13-
from typing import Tuple, Union, Optional
12+
from typing import Optional
1413

1514
from .types import tcpConnectionTypes
1615
from .connection import TcpConnection, TcpConnectionUninitializedException
16+
from ...common.types import HostPort, TcpOrTlsSocket
1717

1818

1919
class TcpClientConnection(TcpConnection):
2020
"""A buffered client connection object."""
2121

2222
def __init__(
2323
self,
24-
conn: Union[ssl.SSLSocket, socket.socket],
24+
conn: TcpOrTlsSocket,
2525
# optional for unix socket servers
26-
addr: Optional[Tuple[str, int]] = None,
26+
addr: Optional[HostPort] = None,
2727
) -> None:
2828
super().__init__(tcpConnectionTypes.CLIENT)
29-
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
30-
self.addr: Optional[Tuple[str, int]] = addr
29+
self._conn: Optional[TcpOrTlsSocket] = conn
30+
self.addr: Optional[HostPort] = addr
3131

3232
@property
3333
def address(self) -> str:
3434
return 'unix:client' if not self.addr else '{0}:{1}'.format(self.addr[0], self.addr[1])
3535

3636
@property
37-
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
37+
def connection(self) -> TcpOrTlsSocket:
3838
if self._conn is None:
3939
raise TcpConnectionUninitializedException()
4040
return self._conn

proxy/core/connection/connection.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11-
import ssl
12-
import socket
1311
import logging
1412
from abc import ABC, abstractmethod
15-
from typing import List, Union, Optional
13+
from typing import List, Optional
1614

1715
from .types import tcpConnectionTypes
16+
from ...common.types import TcpOrTlsSocket
1817
from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE
1918

2019

@@ -44,7 +43,7 @@ def __init__(self, tag: int) -> None:
4443

4544
@property
4645
@abstractmethod
47-
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
46+
def connection(self) -> TcpOrTlsSocket:
4847
"""Must return the socket connection to use in this class."""
4948
raise TcpConnectionUninitializedException() # pragma: no cover
5049

proxy/core/connection/pool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..work import Work
2121
from .server import TcpServerConnection
2222
from ...common.flag import flags
23-
from ...common.types import Readables, Writables, SelectableEvents
23+
from ...common.types import HostPort, Readables, Writables, SelectableEvents
2424

2525

2626
logger = logging.getLogger(__name__)
@@ -73,13 +73,13 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
7373

7474
def __init__(self) -> None:
7575
self.connections: Dict[int, TcpServerConnection] = {}
76-
self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}
76+
self.pools: Dict[HostPort, Set[TcpServerConnection]] = {}
7777

7878
@staticmethod
7979
def create(**kwargs: Any) -> TcpServerConnection: # pragma: no cover
8080
return TcpServerConnection(**kwargs)
8181

82-
def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
82+
def acquire(self, addr: HostPort) -> Tuple[bool, TcpServerConnection]:
8383
"""Returns a reusable connection from the pool.
8484
8585
If none exists, will create and return a new connection."""
@@ -147,7 +147,7 @@ async def handle_events(self, readables: Readables, _writables: Writables) -> bo
147147
self._remove(fileno)
148148
return False
149149

150-
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
150+
def add(self, addr: HostPort) -> TcpServerConnection:
151151
"""Creates, connects and adds a new connection to the pool.
152152
153153
Returns newly created connection.

proxy/core/connection/server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
import ssl
12-
import socket
13-
from typing import Tuple, Union, Optional
12+
from typing import Optional
1413

1514
from .types import tcpConnectionTypes
1615
from .connection import TcpConnection, TcpConnectionUninitializedException
16+
from ...common.types import HostPort, TcpOrTlsSocket
1717
from ...common.utils import new_socket_connection
1818

1919

@@ -22,20 +22,20 @@ class TcpServerConnection(TcpConnection):
2222

2323
def __init__(self, host: str, port: int) -> None:
2424
super().__init__(tcpConnectionTypes.SERVER)
25-
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
26-
self.addr: Tuple[str, int] = (host, port)
25+
self._conn: Optional[TcpOrTlsSocket] = None
26+
self.addr: HostPort = (host, port)
2727
self.closed = True
2828

2929
@property
30-
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
30+
def connection(self) -> TcpOrTlsSocket:
3131
if self._conn is None:
3232
raise TcpConnectionUninitializedException()
3333
return self._conn
3434

3535
def connect(
3636
self,
37-
addr: Optional[Tuple[str, int]] = None,
38-
source_address: Optional[Tuple[str, int]] = None,
37+
addr: Optional[HostPort] = None,
38+
source_address: Optional[HostPort] = None,
3939
) -> None:
4040
assert self._conn is None
4141
self._conn = new_socket_connection(

proxy/core/ssh/handler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
import argparse
12-
from typing import TYPE_CHECKING, Tuple
12+
from typing import TYPE_CHECKING
1313

1414

1515
if TYPE_CHECKING: # pragma: no cover
16+
from ...common.types import HostPort
1617
try:
1718
from paramiko.channel import Channel
1819
except ImportError:
@@ -28,7 +29,7 @@ def __init__(self, flags: argparse.Namespace) -> None:
2829
def on_connection(
2930
self,
3031
chan: 'Channel',
31-
origin: Tuple[str, int],
32-
server: Tuple[str, int],
32+
origin: 'HostPort',
33+
server: 'HostPort',
3334
) -> None:
3435
pass

0 commit comments

Comments
 (0)