diff --git a/CHANGELOG.md b/CHANGELOG.md index dda86a2..9bb6ee4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,15 +9,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -* +* Rate limiting for WebSocket connections (configurable via `PUSHI_RATE_LIMIT_MESSAGES` and `PUSHI_RATE_LIMIT_WINDOW`) +* Connection limits: global (`PUSHI_MAX_CONNECTIONS_GLOBAL`), per-IP (`PUSHI_MAX_CONNECTIONS_PER_IP`), and per-app (`PUSHI_MAX_CONNECTIONS_PER_APP`) +* Message size limits (`PUSHI_MAX_MESSAGE_SIZE`) +* Channel and event name length limits (`PUSHI_MAX_CHANNEL_NAME_LENGTH`, `PUSHI_MAX_EVENT_NAME_LENGTH`) +* Subscription limits: channels per socket (`PUSHI_MAX_CHANNELS_PER_SOCKET`) and sockets per channel (`PUSHI_MAX_SOCKETS_PER_CHANNEL`) +* Webhook timeout configuration (`PUSHI_WEBHOOK_TIMEOUT`) and concurrent request limits (`PUSHI_WEBHOOK_MAX_CONCURRENT`) +* Proper error messages sent to clients when limits are exceeded ### Changed -* +* Webhook calls are now asynchronous and non-blocking with configurable timeout +* Improved error handling in disconnect flow to ensure proper cleanup even on failures ### Fixed -* +* Added `@appier.private` decorator to app creation endpoint to require authentication +* Fixed missing input validation in `handle_event` for required fields +* Fixed potential KeyError when event data is missing required fields +* Improved security by validating event names before dynamic method dispatch ## [0.3.1] - 2024-01-17 diff --git a/src/pushi/app/controllers/app.py b/src/pushi/app/controllers/app.py index 73e5990..fa68f85 100644 --- a/src/pushi/app/controllers/app.py +++ b/src/pushi/app/controllers/app.py @@ -40,6 +40,7 @@ def list(self): apps = pushi.App.find(map=True) return dict(apps=apps) + @appier.private @appier.route("/apps", "POST") def create(self): app = pushi.App.new() diff --git a/src/pushi/base/state.py b/src/pushi/base/state.py index 1ae822b..a302e75 100644 --- a/src/pushi/base/state.py +++ b/src/pushi/base/state.py @@ -49,6 +49,10 @@ from pushi.base import apn from pushi.base import web +# subscription limit configuration +MAX_CHANNELS_PER_SOCKET = int(os.environ.get("PUSHI_MAX_CHANNELS_PER_SOCKET", "100")) +MAX_SOCKETS_PER_CHANNEL = int(os.environ.get("PUSHI_MAX_SOCKETS_PER_CHANNEL", "10000")) + class AppState(object): """ @@ -254,11 +258,29 @@ def disconnect(self, connection, app_key, socket_id): # then uses it to retrieve the complete set of channels that the # socket is subscribed and then unsubscribe it from them then # removes the reference of the socket in the socket channels map - state = self.get_state(app_key=app_key) + try: + state = self.get_state(app_key=app_key) + except Exception: + # if we can't get the state, nothing to clean up + return + channels = state.socket_channels.get(socket_id, []) channels = copy.copy(channels) + + # unsubscribe from each channel with error handling to ensure + # all channels are attempted even if one fails for channel in channels: - self.unsubscribe(connection, app_key, socket_id, channel) + try: + self.unsubscribe(connection, app_key, socket_id, channel) + except Exception as exception: + # log the error but continue cleaning up other channels + if self.app: + self.app.logger.warning( + "Error unsubscribing from channel '%s' during disconnect: %s" + % (channel, appier.legacy.UNICODE(exception)) + ) + + # ensure socket is removed from socket_channels even if unsubscribe failed if socket_id in state.socket_channels: del state.socket_channels[socket_id] @@ -329,6 +351,11 @@ def subscribe( subscribed = channel in channels if subscribed: raise RuntimeError("Channel already subscribed") + + # check if socket has reached the maximum number of channels + if len(channels) >= MAX_CHANNELS_PER_SOCKET: + raise RuntimeError("Maximum channels per socket exceeded") + channels.append(channel) state.socket_channels[socket_id] = channels @@ -339,6 +366,13 @@ def subscribe( subscribed = socket_id in sockets if subscribed: raise RuntimeError("Socket already subscribed") + + # check if channel has reached the maximum number of sockets + if len(sockets) >= MAX_SOCKETS_PER_CHANNEL: + # remove the channel from the socket's channel list since we can't complete + channels.remove(channel) + raise RuntimeError("Maximum sockets per channel exceeded") + sockets.append(socket_id) state.channel_sockets[channel] = sockets diff --git a/src/pushi/base/web.py b/src/pushi/base/web.py index 752e2fe..fb579e2 100644 --- a/src/pushi/base/web.py +++ b/src/pushi/base/web.py @@ -28,14 +28,20 @@ __license__ = "Apache License, Version 2.0" """ The license for the module """ +import os import json +import netius import netius.clients import pushi from . import handler +# webhook configuration +WEBHOOK_TIMEOUT = int(os.environ.get("PUSHI_WEBHOOK_TIMEOUT", "10")) +WEBHOOK_MAX_CONCURRENT = int(os.environ.get("PUSHI_WEBHOOK_MAX_CONCURRENT", "100")) + class WebHandler(handler.Handler): """ @@ -51,6 +57,7 @@ class WebHandler(handler.Handler): def __init__(self, owner): handler.Handler.__init__(self, owner, name="web") self.subs = {} + self._active_requests = 0 def send(self, app_id, event, json_d, invalid={}): # retrieves the reference to the app structure associated with the @@ -96,17 +103,6 @@ def send(self, app_id, event, json_d, invalid={}): data = json.dumps(json_d) headers = {"content-type": "application/json"} - # creates the on message function that is going to be used at the end of - # the request to be able to close the protocol, this is a clojure and so - # current local variables will be exposed to the method - def on_message(protocol, parser, message): - protocol.close() - - # creates the on close function that will be responsible for the stopping - # of the loop as defined by the web implementation - def on_finish(protocol): - netius.compat_loop(loop).stop() - # iterates over the complete set of URLs that are going to # be notified about the message, each of them is going to # received an HTTP post request with the data @@ -117,23 +113,91 @@ def on_finish(protocol): if url in invalid: continue - # prints a debug message about the web message that - # is going to be sent (includes URL) - self.logger.debug("Sending POST request to '%s'" % url) + # check if we have reached the maximum concurrent requests + if self._active_requests >= WEBHOOK_MAX_CONCURRENT: + self.logger.warning( + "Maximum concurrent webhook requests reached, skipping '%s'" % url + ) + continue + + # adds the current URL to the list of invalid items for + # the current message sending stream (do this before sending + # to prevent duplicate sends in case of rapid events) + invalid[url] = True + + # send the webhook using netius async callback pattern + self._send_webhook(url, headers, data) + + def _send_webhook(self, url, headers, data): + """ + Sends a webhook POST request using netius async event loop with timeout. + Uses protocol.delay() for timeout handling instead of threading. + """ + self._active_requests += 1 + self.logger.debug("Sending POST request to '%s'" % url) - # creates the HTTP protocol to be used in the POST request and - # sets the headers and the data then registers for the message - # event so that the loop and protocol may be closed + # track completion state for timeout handling + completed = [False] + + try: + # creates the HTTP protocol to be used in the POST request loop, protocol = netius.clients.HTTPClient.post_s( url, headers=headers, data=data ) - protocol.bind("message", on_message) - protocol.bind("finish", on_finish) - loop.run_forever() + except Exception as exception: + self.logger.warning( + "Error creating HTTP request to '%s': %s" % (url, str(exception)) + ) + self._active_requests -= 1 + return - # adds the current URL to the list of invalid items for - # the current message sending stream - invalid[url] = True + def on_message(protocol, parser, message): + protocol.close() + + def on_finish(protocol): + completed[0] = True + self._active_requests -= 1 + try: + netius.compat_loop(loop).stop() + except Exception: + pass + + def on_error(protocol, error): + self.logger.warning("Webhook error for '%s': %s" % (url, str(error))) + + def on_timeout(): + if not completed[0]: + self.logger.warning( + "Webhook request to '%s' timed out after %d seconds" + % (url, WEBHOOK_TIMEOUT) + ) + try: + protocol.close() + netius.compat_loop(loop).stop() + except Exception: + pass + # decrement only if on_finish hasn't run + if not completed[0]: + completed[0] = True + self._active_requests -= 1 + + # bind event handlers + protocol.bind("message", on_message) + protocol.bind("finish", on_finish) + protocol.bind("error", on_error) + + # schedule timeout using netius event loop delay + protocol.delay(on_timeout, timeout=WEBHOOK_TIMEOUT) + + try: + loop.run_forever() + except Exception as exception: + self.logger.warning( + "Error sending webhook to '%s': %s" % (url, str(exception)) + ) + if not completed[0]: + completed[0] = True + self._active_requests -= 1 def load(self): subs = pushi.Web.find() diff --git a/src/pushi/net/server.py b/src/pushi/net/server.py index a16db7c..686bb8c 100644 --- a/src/pushi/net/server.py +++ b/src/pushi/net/server.py @@ -28,11 +28,25 @@ __license__ = "Apache License, Version 2.0" """ The license for the module """ +import os +import time import uuid import json import netius.servers +# configuration constants with defaults +MAX_CONNECTIONS_GLOBAL = int(os.environ.get("PUSHI_MAX_CONNECTIONS_GLOBAL", "10000")) +MAX_CONNECTIONS_PER_IP = int(os.environ.get("PUSHI_MAX_CONNECTIONS_PER_IP", "100")) +MAX_CONNECTIONS_PER_APP = int(os.environ.get("PUSHI_MAX_CONNECTIONS_PER_APP", "5000")) +MAX_MESSAGE_SIZE = int(os.environ.get("PUSHI_MAX_MESSAGE_SIZE", "65536")) +MAX_CHANNELS_PER_SOCKET = int(os.environ.get("PUSHI_MAX_CHANNELS_PER_SOCKET", "100")) +MAX_SOCKETS_PER_CHANNEL = int(os.environ.get("PUSHI_MAX_SOCKETS_PER_CHANNEL", "10000")) +RATE_LIMIT_MESSAGES = int(os.environ.get("PUSHI_RATE_LIMIT_MESSAGES", "60")) +RATE_LIMIT_WINDOW = int(os.environ.get("PUSHI_RATE_LIMIT_WINDOW", "60")) +MAX_CHANNEL_NAME_LENGTH = int(os.environ.get("PUSHI_MAX_CHANNEL_NAME_LENGTH", "200")) +MAX_EVENT_NAME_LENGTH = int(os.environ.get("PUSHI_MAX_EVENT_NAME_LENGTH", "200")) + class PushiConnection(netius.servers.WSConnection): def __init__(self, *args, **kwargs): @@ -41,6 +55,8 @@ def __init__(self, *args, **kwargs): self.socket_id = str(uuid.uuid4()) self.channels = [] self.count = 0 + self.message_timestamps = [] + self.remote_ip = None def send_pushi(self, json_d): data = json.dumps(json_d) @@ -48,6 +64,27 @@ def send_pushi(self, json_d): self.count += 1 self.owner.count += 1 + def check_rate_limit(self): + """ + Checks if the connection has exceeded the rate limit. + Returns True if the message should be allowed, False if rate limited. + """ + now = time.time() + window_start = now - RATE_LIMIT_WINDOW + + # remove timestamps outside the window + self.message_timestamps = [ + ts for ts in self.message_timestamps if ts > window_start + ] + + # check if we've exceeded the limit + if len(self.message_timestamps) >= RATE_LIMIT_MESSAGES: + return False + + # record this message + self.message_timestamps.append(now) + return True + def load_app(self): app_key = self.path.rsplit("/", 1)[-1] is_unicode = netius.legacy.is_unicode(app_key) @@ -68,6 +105,8 @@ def __init__(self, state=None, *args, **kwargs): self.state = state self.sockets = {} self.count = 0 + self.connections_by_ip = {} + self.connections_by_app = {} def info_dict(self): info = netius.servers.WSServer.info_dict(self) @@ -76,7 +115,38 @@ def info_dict(self): def on_connection_c(self, connection): netius.servers.WSServer.on_connection_c(self, connection) + + # extract the remote IP address from the connection + remote_ip = connection.address[0] if connection.address else "unknown" + connection.remote_ip = remote_ip + + # check global connection limit + if len(self.sockets) >= MAX_CONNECTIONS_GLOBAL: + self._send_error(connection, "Global connection limit exceeded") + connection.close() + return + + # check per-IP connection limit + ip_count = self.connections_by_ip.get(remote_ip, 0) + if ip_count >= MAX_CONNECTIONS_PER_IP: + self._send_error(connection, "Per-IP connection limit exceeded") + connection.close() + return + + # check per-app connection limit + app_key = connection.app_key + if app_key: + app_count = self.connections_by_app.get(app_key, 0) + if app_count >= MAX_CONNECTIONS_PER_APP: + self._send_error(connection, "Per-app connection limit exceeded") + connection.close() + return + self.connections_by_app[app_key] = app_count + 1 + + # track connection self.sockets[connection.socket_id] = connection + self.connections_by_ip[remote_ip] = ip_count + 1 + self.trigger( "connect", connection=connection, @@ -86,7 +156,25 @@ def on_connection_c(self, connection): def on_connection_d(self, connection): netius.servers.WSServer.on_connection_d(self, connection) - del self.sockets[connection.socket_id] + + # clean up socket tracking + if connection.socket_id in self.sockets: + del self.sockets[connection.socket_id] + + # clean up IP tracking + remote_ip = connection.remote_ip + if remote_ip and remote_ip in self.connections_by_ip: + self.connections_by_ip[remote_ip] -= 1 + if self.connections_by_ip[remote_ip] <= 0: + del self.connections_by_ip[remote_ip] + + # clean up app tracking + app_key = connection.app_key + if app_key and app_key in self.connections_by_app: + self.connections_by_app[app_key] -= 1 + if self.connections_by_app[app_key] <= 0: + del self.connections_by_app[app_key] + self.trigger( "disconnect", connection=connection, @@ -94,6 +182,16 @@ def on_connection_d(self, connection): socket_id=connection.socket_id, ) + def _send_error(self, connection, message): + """ + Sends an error message to the connection. + """ + json_d = dict(event="pusher:error", data=json.dumps(dict(message=message))) + try: + connection.send_pushi(json_d) + except Exception: + pass + def build_connection(self, socket, address, ssl=False): return PushiConnection(self, socket, address, ssl=ssl) @@ -117,13 +215,38 @@ def on_data_ws(self, connection, data): if data == cls.WS_CLOSE_FRAME: return + # check message size limit + if len(data) > MAX_MESSAGE_SIZE: + self._send_error(connection, "Message size exceeds limit") + return + + # check rate limit + if not connection.check_rate_limit(): + self._send_error(connection, "Rate limit exceeded") + return + try: data = data.decode("utf-8") json_d = json.loads(data) except Exception: raise netius.DataError("Invalid message received '%s'" % data) + # validate event field exists and is a string event = json_d.get("event", None) + if not event or not isinstance(event, str): + self._send_error(connection, "Invalid or missing event field") + return + + # validate event name length + if len(event) > MAX_EVENT_NAME_LENGTH: + self._send_error(connection, "Event name too long") + return + + # sanitize event name for method dispatch (only allow alphanumeric and colon/underscore) + if not all(c.isalnum() or c in ":_-" for c in event): + self._send_error(connection, "Invalid characters in event name") + return + event = event.replace(":", "_") method_name = "handle_" + event @@ -140,6 +263,20 @@ def handle_pusher_subscribe(self, connection, json_d): auth = data.get("auth", None) channel_data = data.get("channel_data", None) + # validate channel name + if not channel or not isinstance(channel, str): + self._send_error(connection, "Invalid or missing channel name") + return + + if len(channel) > MAX_CHANNEL_NAME_LENGTH: + self._send_error(connection, "Channel name too long") + return + + # check channel subscription limit per socket + if len(connection.channels) >= MAX_CHANNELS_PER_SOCKET: + self._send_error(connection, "Maximum channels per socket exceeded") + return + self.trigger( "subscribe", connection=connection, @@ -210,12 +347,33 @@ def handle_pusher_latest(self, connection, json_d): connection.send_pushi(json_d) def handle_event(self, connection, json_d): + # validate required fields exist + if "data" not in json_d: + self._send_error(connection, "Missing 'data' field in event") + return + if "event" not in json_d: + self._send_error(connection, "Missing 'event' field in event") + return + if "channel" not in json_d: + self._send_error(connection, "Missing 'channel' field in event") + return + data = json_d["data"] event = json_d["event"] channel = json_d["channel"] echo = json_d.get("echo", False) persist = json_d.get("persist", True) + # validate channel name + if not isinstance(channel, str) or len(channel) > MAX_CHANNEL_NAME_LENGTH: + self._send_error(connection, "Invalid channel name") + return + + # validate event name + if not isinstance(event, str) or len(event) > MAX_EVENT_NAME_LENGTH: + self._send_error(connection, "Invalid event name") + return + if not self.state: return