Skip to content

Commit f48aac4

Browse files
[Optimize] Avoid using tobytes for zero-copies (#1066)
* Avoid using `tobytes` where possible * `send` accepts `Union[memoryview, bytes]` now * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 600b3e7 commit f48aac4

File tree

5 files changed

+13
-12
lines changed

5 files changed

+13
-12
lines changed

proxy/core/connection/connection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111
import logging
1212
from abc import ABC, abstractmethod
13-
from typing import List, Optional
13+
from typing import List, Union, Optional
1414

1515
from .types import tcpConnectionTypes
1616
from ...common.types import TcpOrTlsSocket
@@ -47,7 +47,7 @@ def connection(self) -> TcpOrTlsSocket:
4747
"""Must return the socket connection to use in this class."""
4848
raise TcpConnectionUninitializedException() # pragma: no cover
4949

50-
def send(self, data: bytes) -> int:
50+
def send(self, data: Union[memoryview, bytes]) -> int:
5151
"""Users must handle BrokenPipeError exceptions"""
5252
# logger.info(data)
5353
return self.connection.send(data)
@@ -83,16 +83,16 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
8383
"""Users must handle BrokenPipeError exceptions"""
8484
if not self.has_buffer():
8585
return 0
86-
mv = self.buffer[0].tobytes()
87-
max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
86+
mv = self.buffer[0]
8887
# TODO: Assemble multiple packets if total
8988
# size remains below max send size.
89+
max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
9090
sent: int = self.send(mv[:max_send_size])
9191
if sent == len(mv):
9292
self.buffer.pop(0)
9393
self._num_buffer -= 1
9494
else:
95-
self.buffer[0] = memoryview(mv[sent:])
95+
self.buffer[0] = mv[sent:]
9696
del mv
9797
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
9898
return sent

proxy/http/handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,9 @@ def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHa
267267

268268
def _parse_first_request(self, data: memoryview) -> bool:
269269
# Parse http request
270-
#
271-
# TODO(abhinavsingh): Remove .tobytes after parser is
272-
# memoryview compliant
273270
try:
271+
# TODO(abhinavsingh): Remove .tobytes after parser is
272+
# memoryview compliant
274273
self.request.parse(data.tobytes())
275274
except HttpProtocolException as e: # noqa: WPS329
276275
self.work.queue(BAD_REQUEST_RESPONSE_PKT)

proxy/http/server/web.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,10 @@ async def read_from_descriptors(self, r: Readables) -> bool:
174174

175175
def on_client_data(self, raw: memoryview) -> None:
176176
if self.switched_protocol == httpProtocolTypes.WEBSOCKET:
177-
# TODO(abhinavsingh): Remove .tobytes after websocket frame parser
178-
# is memoryview compliant
177+
# TODO(abhinavsingh): Do we really tobytes() here?
178+
# Websocket parser currently doesn't depend on internal
179+
# buffers, due to which it can directly parse out of
180+
# memory views. But how about large payloads scenarios?
179181
remaining = raw.tobytes()
180182
frame = WebsocketFrame()
181183
while remaining != b'':

proxy/http/websocket/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def run_once(self) -> bool:
9696
if mask & selectors.EVENT_READ and self.on_message:
9797
# TODO: client recvbuf size flag currently not used here
9898
raw = self.recv()
99-
if raw is None or raw.tobytes() == b'':
99+
if raw is None or raw == b'':
100100
self.closed = True
101101
return True
102102
frame = WebsocketFrame()

tests/http/test_protocol_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ async def assert_data_queued(
456456
CRLF,
457457
])
458458
server.queue.assert_called_once()
459-
self.assertEqual(server.queue.call_args_list[0][0][0].tobytes(), pkt)
459+
self.assertEqual(server.queue.call_args_list[0][0][0], pkt)
460460
server.buffer_size.return_value = len(pkt)
461461

462462
async def assert_data_queued_to_server(self, server: mock.Mock) -> None:

0 commit comments

Comments
 (0)