diff --git a/realtime/_async/channel.py b/realtime/_async/channel.py index b39ddbd..b4b3069 100644 --- a/realtime/_async/channel.py +++ b/realtime/_async/channel.py @@ -16,6 +16,7 @@ RealtimeSubscribeStates, ) +from ..logging_util import TokenMaskingFilter from ..transformers import http_endpoint_url from .presence import ( AsyncRealtimePresence, @@ -29,6 +30,7 @@ from .client import AsyncRealtimeClient logger = logging.getLogger(__name__) +logger.addFilter(TokenMaskingFilter()) class AsyncRealtimeChannel: diff --git a/realtime/_async/client.py b/realtime/_async/client.py index fda8487..88bbc76 100644 --- a/realtime/_async/client.py +++ b/realtime/_async/client.py @@ -8,6 +8,7 @@ import websockets from ..exceptions import NotConnectedError +from ..logging_util import TokenMaskingFilter from ..message import Message from ..transformers import http_endpoint_url from ..types import ( @@ -21,6 +22,7 @@ from .channel import AsyncRealtimeChannel, RealtimeChannelOptions logger = logging.getLogger(__name__) +logger.addFilter(TokenMaskingFilter()) def ensure_connection(func: Callback): @@ -123,7 +125,7 @@ async def connect(self) -> None: while retries < self.max_retries: try: - self.ws_connection = await websockets.connect(self.url) + self.ws_connection = await websockets.connect(self.url, logger=logger) if self.ws_connection.open: logger.info("Connection was successful") return await self._on_connect() diff --git a/realtime/_async/push.py b/realtime/_async/push.py index 06c62fb..fe17127 100644 --- a/realtime/_async/push.py +++ b/realtime/_async/push.py @@ -2,12 +2,14 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional +from ..logging_util import TokenMaskingFilter from ..types import DEFAULT_TIMEOUT, Callback, _Hook if TYPE_CHECKING: from .channel import AsyncRealtimeChannel logger = logging.getLogger(__name__) +logger.addFilter(TokenMaskingFilter()) class AsyncPush: diff --git a/realtime/logging_util.py b/realtime/logging_util.py new file mode 
100644 index 0000000..cf04c30 --- /dev/null +++ b/realtime/logging_util.py @@ -0,0 +1,43 @@ +import copy +import logging +import re + +# redaction regex for detecting JWT tokens +#
# JWT layout: header.payload.signature
# character set [a-zA-Z0-9_-]
# \w covers [a-zA-Z0-9_] (plus other Unicode word characters)
# "eyJh" is the base64url encoding of '{"a', i.e. the start of a JSON header
# such as {"alg": ...}
redact = r"(eyJh[-_\w]*\.)([-_\w]*)\."


def gred(g):
    """Redact the payload of the JWT, keeping the header and signature.

    Used as the ``repl`` callable for ``re.sub``: *g* is the match object and
    the return value is the replacement text.  When the match does not carry
    the two expected groups, the matched text is returned unchanged
    (``re.sub`` requires a string, not the match object itself).
    """
    if len(g.groups()) > 1:
        return f"{g.group(1)}REDACTED."
    return g.group(0)


class TokenMaskingFilter(logging.Filter):
    """Mask access_tokens in logs"""

    def filter(self, record):
        """Sanitize the record's message and arguments in place.

        Always returns True: the record is never dropped, only rewritten.
        """
        record.msg = self.sanitize_line(record.msg)
        record.args = self.sanitize_args(record.args)
        return True

    @staticmethod
    def sanitize_args(d):
        """Return a copy of logging args *d* with JWT payloads redacted.

        * dict  -> shallow copy whose string values are sanitized
        * tuple -> new tuple whose string items are sanitized
        * anything else (e.g. ``None``) is returned unchanged

        Non-string values are passed through as the very same objects; the
        caller's dict/tuple is never modified.
        """
        sanitize = TokenMaskingFilter.sanitize_line
        if isinstance(d, dict):
            masked = d.copy()  # so we don't overwrite the caller's mapping
            for key, value in d.items():
                masked[key] = sanitize(value)
            return masked
        if isinstance(d, tuple):
            # Strings are immutable and we build a brand-new tuple, so the
            # original args are untouched without any (deep) copying.
            return tuple(sanitize(value) for value in d)
        return d

    @staticmethod
    def sanitize_line(line):
        """Redact JWT payloads in *line*.

        Non-string values (exceptions, numbers, arbitrary objects used as log
        messages or args) are returned unchanged instead of raising.
        """
        if not isinstance(line, str):
            return line
        return re.sub(redact, gred, line)