diff --git a/.bumpversion.cfg b/.bumpversion.cfg index f4ccffb..8094aca 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.4.1 +current_version = 0.4.2 commit = True tag = True diff --git a/CHANGELOG.md b/CHANGELOG.md index e240e9d..88952f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## [0.4.1] - 2025-06-05 +### Added +- Custom log format in config.ini +- Add __main__.py to run module +- Colored log + ## [0.3.3] - 2025-05-30 ### Added - Change pylint to flake8 diff --git a/pyproxy/__init__.py b/pyproxy/__init__.py index d1ba048..dd3677d 100644 --- a/pyproxy/__init__.py +++ b/pyproxy/__init__.py @@ -5,7 +5,7 @@ import os -__version__ = "0.4.1" +__version__ = "0.4.2" if os.path.isdir("pyproxy/monitoring"): __slim__ = False diff --git a/pyproxy/handlers/client.py b/pyproxy/handlers/client.py index 406bbd1..bb3bac7 100644 --- a/pyproxy/handlers/client.py +++ b/pyproxy/handlers/client.py @@ -39,9 +39,7 @@ def __init__( shortcuts, custom_header, active_connections, - proxy_enable, - proxy_host, - proxy_port, + proxy_config, ): self.html_403 = html_403 self.logger_config = logger_config @@ -58,11 +56,32 @@ def __init__( self.console_logger = console_logger self.config_shortcuts = shortcuts self.config_custom_header = custom_header - self.proxy_enable = proxy_enable - self.proxy_host = proxy_host - self.proxy_port = proxy_port + self.proxy_config = proxy_config self.active_connections = active_connections + def _create_handler(self, handler_class, **extra_kwargs): + """ + Factory to create handler instance with shared common parameters. + """ + params = dict( + html_403=self.html_403, + logger_config=self.logger_config, + filter_config=self.filter_config, + filter_queue=self.filter_queue, + filter_result_queue=self.filter_result_queue, + shortcuts_queue=self.shortcuts_queue, + shortcuts_result_queue=self.shortcuts_result_queue, + custom_header_queue=self.custom_header_queue, + custom_header_result_queue=self.custom_header_result_queue, + console_logger=self.console_logger, + shortcuts=self.config_shortcuts, + custom_header=self.config_custom_header, + proxy_config=self.proxy_config, + active_connections=self.active_connections, + ) + params.update(extra_kwargs) + return handler_class(**params) + def handle_client(self, client_socket): """ Handles an incoming client connection by processing the request and forwarding @@ -82,45 +101,13 @@ def handle_client(self, client_socket): first_line = request.decode(errors="ignore").split("\n")[0] if first_line.startswith("CONNECT"): - client_https_handler = HttpsHandler( - html_403=self.html_403, - logger_config=self.logger_config, - filter_config=self.filter_config, + https_handler = self._create_handler( + HttpsHandler, ssl_config=self.ssl_config, - filter_queue=self.filter_queue, - filter_result_queue=self.filter_result_queue, - shortcuts_queue=self.shortcuts_queue, - shortcuts_result_queue=self.shortcuts_result_queue, cancel_inspect_queue=self.cancel_inspect_queue, cancel_inspect_result_queue=self.cancel_inspect_result_queue, - custom_header_queue=self.custom_header_queue, - custom_header_result_queue=self.custom_header_result_queue, - console_logger=self.console_logger, - shortcuts=self.config_shortcuts, - custom_header=self.config_custom_header, - proxy_enable=self.proxy_enable, - proxy_host=self.proxy_host, - proxy_port=self.proxy_port, - active_connections=self.active_connections, ) - client_https_handler.handle_https_connection(client_socket, first_line) + https_handler.handle_https_connection(client_socket, first_line) else: - client_http_handler = HttpHandler( - html_403=self.html_403, - logger_config=self.logger_config, - filter_config=self.filter_config, - filter_queue=self.filter_queue, - filter_result_queue=self.filter_result_queue, - shortcuts_queue=self.shortcuts_queue, - shortcuts_result_queue=self.shortcuts_result_queue, - custom_header_queue=self.custom_header_queue, - custom_header_result_queue=self.custom_header_result_queue, - console_logger=self.console_logger, - shortcuts=self.config_shortcuts, - custom_header=self.config_custom_header, - proxy_enable=self.proxy_enable, - proxy_host=self.proxy_host, - proxy_port=self.proxy_port, - active_connections=self.active_connections, - ) - client_http_handler.handle_http_request(client_socket, request) + http_handler = self._create_handler(HttpHandler) + http_handler.handle_http_request(client_socket, request) diff --git a/pyproxy/handlers/http.py b/pyproxy/handlers/http.py index a71612d..b583b28 100644 --- a/pyproxy/handlers/http.py +++ b/pyproxy/handlers/http.py @@ -8,8 +8,9 @@ import socket import os import threading +from urllib.parse import urlparse -from pyproxy.utils.http_req import extract_headers, parse_url +from pyproxy.utils.http_req import extract_headers class HttpHandler: @@ -34,9 +35,7 @@ def __init__( shortcuts, custom_header, active_connections, - proxy_enable, - proxy_host, - proxy_port, + proxy_config, ): self.html_403 = html_403 self.logger_config = logger_config @@ -50,11 +49,81 @@ def __init__( self.console_logger = console_logger self.config_shortcuts = shortcuts self.config_custom_header = custom_header - self.proxy_enable = proxy_enable - self.proxy_host = proxy_host - self.proxy_port = proxy_port + self.proxy_config = proxy_config self.active_connections = active_connections + def _get_modified_headers(self, url, request_text): + """ + Extract headers from a request + """ + headers = extract_headers(request_text) + self.custom_header_queue.put(url) + try: + new_headers = self.custom_header_result_queue.get(timeout=5) + headers.update(new_headers) + except Exception: + self.console_logger.warning( + "Timeout while getting custom headers for %s", url + ) + return headers + + def _rebuild_http_request(self, request_line, headers, body=""): + """ + Reconstructs an HTTP request with the new headers. + """ + header_lines = [f"{key}: {value}" for key, value in headers.items()] + reconstructed_headers = "\r\n".join(header_lines) + return f"{request_line}\r\n{reconstructed_headers}\r\n\r\n{body}".encode() + + def _apply_shortcut(self, url: str) -> str | None: + """ + Checks if a shortcut is defined for the given domain. + """ + if self.config_shortcuts and os.path.isfile(self.config_shortcuts): + parsed_url = urlparse(url) + domain = parsed_url.hostname + self.shortcuts_queue.put(domain) + try: + return self.shortcuts_result_queue.get(timeout=5) + except Exception: + self.console_logger.warning( + "Timeout while getting shortcut for %s", url + ) + return None + + def _is_blocked(self, url: str) -> bool: + """ + Checks if a URL is blocked by the configuration filter. + """ + if not self.filter_config.no_filter: + self.filter_queue.put(url) + try: + result = self.filter_result_queue.get(timeout=5) + return result[1] == "Blocked" + except Exception: + self.console_logger.warning("Timeout while filtering %s", url) + return False + + def _send_403(self, client_socket, url, first_line): + """ + Sends an HTTP 403 Forbidden response to the client. + """ + if not self.logger_config.no_logging_block: + self.logger_config.block_logger.info( + "%s - %s - %s", client_socket.getpeername()[0], url, first_line + ) + with open(self.html_403, "r", encoding="utf-8") as f: + custom_403_page = f.read() + response = ( + f"HTTP/1.1 403 Forbidden\r\n" + f"Content-Length: {len(custom_403_page)}\r\n" + f"\r\n" + f"{custom_403_page}" + ) + client_socket.sendall(response.encode()) + client_socket.close() + self.active_connections.pop(threading.get_ident(), None) + def handle_http_request(self, client_socket, request): """ Processes an HTTP request, checks for URL filtering, applies shortcuts, @@ -67,16 +136,8 @@ def handle_http_request(self, client_socket, request): first_line = request.decode(errors="ignore").split("\n")[0] url = first_line.split(" ")[1] - if self.config_custom_header and os.path.isfile(self.config_custom_header): - headers = extract_headers(request.decode(errors="ignore")) - self.custom_header_queue.put(url) - new_headers = self.custom_header_result_queue.get(timeout=5) - headers.update(new_headers) - if self.config_shortcuts and os.path.isfile(self.config_shortcuts): - domain, _ = parse_url(url) - self.shortcuts_queue.put(domain) - shortcut_url = self.shortcuts_result_queue.get(timeout=5) + shortcut_url = self._apply_shortcut(url) if shortcut_url: response = ( f"HTTP/1.1 302 Found\r\n" @@ -90,27 +151,12 @@ def handle_http_request(self, client_socket, request): self.active_connections.pop(threading.get_ident(), None) return - if not self.filter_config.no_filter: - self.filter_queue.put(url) - result = self.filter_result_queue.get(timeout=5) - if result[1] == "Blocked": - if not self.logger_config.no_logging_block: - self.logger_config.block_logger.info( - "%s - %s - %s", client_socket.getpeername()[0], url, first_line - ) - with open(self.html_403, "r", encoding="utf-8") as f: - custom_403_page = f.read() - response = ( - f"HTTP/1.1 403 Forbidden\r\n" - f"Content-Length: {len(custom_403_page)}\r\n" - f"\r\n" - f"{custom_403_page}" - ) - client_socket.sendall(response.encode()) - client_socket.close() - self.active_connections.pop(threading.get_ident(), None) - return - server_host, _ = parse_url(url) + if self._is_blocked(url): + self._send_403(client_socket, url, first_line) + return + + parsed_url = urlparse(url) + server_host = parsed_url.hostname if not self.logger_config.no_logging_access: self.logger_config.access_logger.info( "%s - %s - %s", @@ -120,20 +166,16 @@ def handle_http_request(self, client_socket, request): ) if self.config_custom_header and os.path.isfile(self.config_custom_header): - request_lines = request.decode(errors="ignore").split("\r\n") - request_line = request_lines[0] # GET / HTTP/1.1 - - header_lines = [f"{key}: {value}" for key, value in headers.items()] - reconstructed_headers = "\r\n".join(header_lines) - - if "\r\n\r\n" in request.decode(errors="ignore"): - body = request.decode(errors="ignore").split("\r\n\r\n", 1)[1] - else: - body = "" - - modified_request = ( - f"{request_line}\r\n{reconstructed_headers}\r\n\r\n{body}".encode() + request_text = request.decode(errors="ignore") + request_lines = request_text.split("\r\n") + headers = self._get_modified_headers(url, request_text) + request_line = request_lines[0] + body = ( + request_text.split("\r\n\r\n", 1)[1] + if "\r\n\r\n" in request_text + else "" ) + modified_request = self._rebuild_http_request(request_line, headers, body) self.forward_request_to_server(client_socket, modified_request, url) @@ -149,10 +191,14 @@ def forward_request_to_server(self, client_socket, request, url): request (bytes): The raw HTTP request sent by the client. url (str): The target URL from the HTTP request. """ - if self.proxy_enable: - server_host, server_port = self.proxy_host, self.proxy_port + if self.proxy_config.enable: + server_host, server_port = self.proxy_config.host, self.proxy_config.port else: - server_host, server_port = parse_url(url) + parsed_url = urlparse(url) + server_host = parsed_url.hostname + server_port = parsed_url.port or ( + 443 if parsed_url.scheme == "https" else 80 + ) thread_id = threading.get_ident() if thread_id in self.active_connections: diff --git a/pyproxy/handlers/https.py b/pyproxy/handlers/https.py index d01f50e..ddabeeb 100644 --- a/pyproxy/handlers/https.py +++ b/pyproxy/handlers/https.py @@ -42,9 +42,7 @@ def __init__( shortcuts, custom_header, active_connections, - proxy_enable, - proxy_host, - proxy_port, + proxy_config, ): self.html_403 = html_403 self.logger_config = logger_config @@ -61,11 +59,149 @@ def __init__( self.console_logger = console_logger self.config_shortcuts = shortcuts self.config_custom_header = custom_header - self.proxy_enable = proxy_enable - self.proxy_host = proxy_host - self.proxy_port = proxy_port + self.proxy_config = proxy_config self.active_connections = active_connections + def _is_blocked(self, url: str) -> bool: + """ + Checks if a URL is blocked by the configuration filter. + """ + if not self.filter_config.no_filter: + self.filter_queue.put(url) + try: + result = self.filter_result_queue.get(timeout=5) + return result[1] == "Blocked" + except Exception: + self.console_logger.warning("Timeout while filtering %s", url) + return False + + def _send_403(self, client_socket, url, first_line): + """ + Sends an HTTP 403 Forbidden response to the client. + """ + if not self.logger_config.no_logging_block: + self.logger_config.block_logger.info( + "%s - %s - %s", client_socket.getpeername()[0], url, first_line + ) + with open(self.html_403, "r", encoding="utf-8") as f: + custom_403_page = f.read() + response = ( + f"HTTP/1.1 403 Forbidden\r\n" + f"Content-Length: {len(custom_403_page)}\r\n" + f"\r\n" + f"{custom_403_page}" + ) + client_socket.sendall(response.encode()) + client_socket.close() + self.active_connections.pop(threading.get_ident(), None) + + def _should_skip_inspection(self, server_host: str) -> bool: + """ + Determine if SSL inspection should be skipped for the given host. + """ + if ( + self.ssl_config.ssl_inspect + and self.ssl_config.cancel_inspect + and os.path.isfile(self.ssl_config.cancel_inspect) + ): + self.cancel_inspect_queue.put(server_host) + return self.cancel_inspect_result_queue.get(timeout=5) + return False + + def _establish_server_connection(self, server_host, server_port): + """ + Create and return a socket connected to the target server. + """ + if self.proxy_config.enable: + next_proxy_socket = socket.create_connection( + (self.proxy_config.host, self.proxy_config.port) + ) + connect_command = ( + f"CONNECT {server_host}:{server_port} HTTP/1.1\r\n" + f"Host: {server_host}:{server_port}\r\n\r\n" + ) + next_proxy_socket.sendall(connect_command.encode()) + + response = b"" + while b"\r\n\r\n" not in response: + chunk = next_proxy_socket.recv(4096) + if not chunk: + raise ConnectionError("Connection to next proxy failed") + response += chunk + + if b"200 Connection Established" not in response: + raise ConnectionRefusedError("Next proxy refused CONNECT") + + return next_proxy_socket + else: + return socket.create_connection((server_host, server_port)) + + def _wrap_client_socket_with_ssl(self, client_socket, cert_path, key_path): + """ + Wrap the client socket with an SSL context for interception. + """ + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + client_context.load_cert_chain(certfile=cert_path, keyfile=key_path) + client_context.options |= ( + ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ) + client_context.load_verify_locations(self.ssl_config.inspect_ca_cert) + + ssl_client_socket = client_context.wrap_socket( + client_socket, server_side=True, do_handshake_on_connect=False + ) + ssl_client_socket.do_handshake() + return ssl_client_socket + + def _wrap_server_socket_with_ssl(self, server_socket, server_host): + """ + Wrap the server socket with an SSL context to enable encrypted communication. + """ + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self.proxy_config.enable: + server_context.check_hostname = False + server_context.verify_mode = ssl.CERT_NONE + else: + server_context.load_default_certs() + + ssl_server_socket = server_context.wrap_socket( + server_socket, + server_hostname=server_host, + do_handshake_on_connect=True, + ) + return ssl_server_socket + + def _process_first_ssl_request(self, ssl_client_socket, server_host): + """ + Reads and processes the first SSL client request, extracts the method and full URL. + """ + try: + first_request = ssl_client_socket.recv(4096).decode(errors="ignore") + if not first_request: + raise ConnectionError("Empty request received") + + request_line = first_request.split("\r\n")[0] + method, path, _ = request_line.split(" ") + + full_url = f"https://{server_host}{path}" + + if self._is_blocked(f"{server_host}{path}"): + return None, full_url, True + + if not self.logger_config.no_logging_access: + self.logger_config.access_logger.info( + "%s - %s - %s %s", + ssl_client_socket.getpeername()[0], + f"https://{server_host}", + method, + full_url, + ) + + return first_request, full_url, False + except Exception as e: + self.logger_config.error_logger.error(f"SSL request processing error : {e}") + return None, None, False + def handle_https_connection(self, client_socket, first_line): """ Handles HTTPS connections by establishing a connection with the target server @@ -79,146 +215,45 @@ def handle_https_connection(self, client_socket, first_line): server_host, server_port = target.split(":") server_port = int(server_port) - if not self.filter_config.no_filter: - self.filter_queue.put(target) - result = self.filter_result_queue.get(timeout=5) - if result[1] == "Blocked": - if not self.logger_config.no_logging_block: - self.logger_config.block_logger.info( - "%s - %s - %s", - client_socket.getpeername()[0], - target, - first_line, - ) - with open(self.html_403, "r", encoding="utf-8") as f: - custom_403_page = f.read() - response = ( - f"HTTP/1.1 403 Forbidden\r\n" - f"Content-Length: {len(custom_403_page)}\r\n" - f"\r\n" - f"{custom_403_page}" - ) - client_socket.sendall(response.encode()) - client_socket.close() - self.active_connections.pop(threading.get_ident(), None) - return + if self._is_blocked(target): + self._send_403(client_socket, target, first_line) + return - not_inspect = False - if ( - self.ssl_config.ssl_inspect - and self.ssl_config.cancel_inspect - and os.path.isfile(self.ssl_config.cancel_inspect) - ): - self.cancel_inspect_queue.put(server_host) - not_inspect = self.cancel_inspect_result_queue.get(timeout=5) + not_inspect = self._should_skip_inspection(server_host) if self.ssl_config.ssl_inspect and not not_inspect: - cert_path, key_path = generate_certificate( - server_host, - self.ssl_config.inspect_certs_folder, - self.ssl_config.inspect_ca_cert, - self.ssl_config.inspect_ca_key, - ) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - client_context.load_cert_chain(certfile=cert_path, keyfile=key_path) - client_context.options |= ( - ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - ) - client_context.load_verify_locations(self.ssl_config.inspect_ca_cert) - try: - client_socket.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n") - ssl_client_socket = client_context.wrap_socket( - client_socket, server_side=True, do_handshake_on_connect=False - ) - ssl_client_socket.do_handshake() - - if self.proxy_enable: - next_proxy_socket = socket.create_connection( - (self.proxy_host, self.proxy_port) - ) - connect_command = ( - f"CONNECT {server_host}:{server_port} HTTP/1.1\r\n" - f"Host: {server_host}:{server_port}\r\n\r\n" - ) - next_proxy_socket.sendall(connect_command.encode()) - - response = b"" - while b"\r\n\r\n" not in response: - chunk = next_proxy_socket.recv(4096) - if not chunk: - raise ConnectionError("Connection to next proxy failed") - response += chunk - - if b"200 Connection Established" not in response: - raise ConnectionRefusedError("Next proxy refused CONNECT") - - server_socket = next_proxy_socket - else: - server_socket = socket.create_connection((server_host, server_port)) - - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if self.proxy_enable: - server_context.check_hostname = False - server_context.verify_mode = ssl.CERT_NONE - else: - server_context.load_default_certs() - - ssl_server_socket = server_context.wrap_socket( - server_socket, - server_hostname=server_host, - do_handshake_on_connect=True, + cert_path, key_path = generate_certificate( + server_host, + self.ssl_config.inspect_certs_folder, + self.ssl_config.inspect_ca_cert, + self.ssl_config.inspect_ca_key, ) - try: - first_request = ssl_client_socket.recv(4096).decode(errors="ignore") - request_line = first_request.split("\r\n")[0] - method, path, _ = request_line.split(" ") - - full_url = f"https://{server_host}{path}" - - if not self.filter_config.no_filter: - self.filter_queue.put(f"{server_host}{path}") - result = self.filter_result_queue.get(timeout=5) - if result[1] == "Blocked": - if not self.logger_config.no_logging_block: - self.logger_config.block_logger.info( - "%s - %s - %s", - ssl_client_socket.getpeername()[0], - target, - first_line, - ) - with open(self.html_403, "r", encoding="utf-8") as f: - custom_403_page = f.read() - response = ( - f"HTTP/1.1 403 Forbidden\r\n" - f"Content-Length: {len(custom_403_page)}\r\n" - f"\r\n" - f"{custom_403_page}" - ) - ssl_client_socket.sendall(response.encode()) - ssl_client_socket.close() - self.active_connections.pop(threading.get_ident(), None) - return - - if not self.logger_config.no_logging_access: - self.logger_config.access_logger.info( - "%s - %s - %s %s", - ssl_client_socket.getpeername()[0], - f"https://{server_host}", - method, - full_url, - ) + client_socket.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n") - ssl_server_socket.sendall(first_request.encode()) + ssl_client_socket = self._wrap_client_socket_with_ssl( + client_socket, cert_path, key_path + ) - except ValueError: - self.console_logger.error( - "Error parsing request: malformed request line." - ) + server_socket = self._establish_server_connection( + server_host, server_port + ) + ssl_server_socket = self._wrap_server_socket_with_ssl( + server_socket, server_host + ) - except (socket.error, ssl.SSLError) as e: - self.console_logger.error("Network or SSL error : %s", str(e)) + first_request, full_url, is_blocked = self._process_first_ssl_request( + ssl_client_socket, server_host + ) + if is_blocked: + self._send_403(ssl_client_socket, target, first_line) + return + if first_request is None: + ssl_client_socket.close() + return + + ssl_server_socket.sendall(first_request.encode()) self.transfer_data_between_sockets(ssl_client_socket, ssl_server_socket) diff --git a/pyproxy/pyproxy.py b/pyproxy/pyproxy.py index 981961d..da4bdcb 100644 --- a/pyproxy/pyproxy.py +++ b/pyproxy/pyproxy.py @@ -6,7 +6,14 @@ from .server import ProxyServer from .utils.args import parse_args, load_config, get_config_value, str_to_bool -from .utils.config import ProxyConfigLogger, ProxyConfigFilter, ProxyConfigSSL +from .utils.config import ( + ProxyConfigLogger, + ProxyConfigFilter, + ProxyConfigSSL, + ProxyConfigMonitoring, + ProxyConfigProxy, + ProxyConfigMain, +) def main(): @@ -17,31 +24,39 @@ def main(): args = parse_args() config = load_config(args.config_file) - host = get_config_value(args, config, "host", "Server", "0.0.0.0") # nosec - port = int(get_config_value(args, config, "port", "Server", 8080)) # nosec - debug = get_config_value(args, config, "debug", "Logging", False) - html_403 = get_config_value(args, config, "html_403", "Files", "assets/403.html") - shortcuts = get_config_value( - args, config, "shortcuts", "Options", "config/shortcuts.txt" + main_config = ProxyConfigMain( + host=get_config_value(args, config, "host", "Server", "0.0.0.0"), # nosec + port=int(get_config_value(args, config, "port", "Server", 8080)), # nosec + debug=str_to_bool(get_config_value(args, config, "debug", "Logging", False)), + html_403=get_config_value(args, config, "html_403", "Files", "assets/403.html"), + shortcuts=get_config_value( + args, config, "shortcuts", "Options", "config/shortcuts.txt" + ), + custom_header=get_config_value( + args, config, "custom_header", "Options", "config/custom_header.json" + ), + authorized_ips=get_config_value( + args, config, "authorized_ips", "Options", "config/authorized_ips.txt" + ), ) - custom_header = get_config_value( - args, config, "custom_header", "Options", "config/custom_header.json" + + monitoring_config = ProxyConfigMonitoring( + flask_port=get_config_value(args, config, "flask_port", "Monitoring", 5000), + flask_pass=get_config_value( + args, config, "flask_pass", "Monitoring", "password" + ), ) - authorized_ips = get_config_value( - args, config, "authorized_ips", "Options", "config/authorized_ips.txt" + + proxy_config = ProxyConfigProxy( + enable=str_to_bool( + get_config_value(args, config, "proxy_enable", "Proxy", False) + ), + host=get_config_value(args, config, "proxy_host", "Proxy", "127.0.0.1"), + port=get_config_value(args, config, "proxy_port", "Proxy", 8081), ) - flask_port = get_config_value(args, config, "flask_port", "Monitoring", 5000) - flask_pass = get_config_value(args, config, "flask_pass", "Monitoring", "password") - proxy_enable = get_config_value(args, config, "proxy_enable", "Proxy", False) - proxy_host = get_config_value(args, config, "proxy_host", "Proxy", "127.0.0.1") - proxy_port = get_config_value(args, config, "proxy_port", "Proxy", 8081) - console_format = None - if config.has_section("Logging") and config.has_option("Logging", "console_format"): - console_format = config.get("Logging", "console_format") - datefmt = None - if config.has_section("Logging") and config.has_option("Logging", "datefmt"): - datefmt = config.get("Logging", "datefmt") + console_format = config.get("Logging", "console_format", fallback=None) + datefmt = config.get("Logging", "datefmt", fallback=None) logger_config = ProxyConfigLogger( access_log=get_config_value( @@ -96,25 +111,12 @@ def main(): ) proxy = ProxyServer( - host=host, - port=port, - debug=str_to_bool(debug), + main_config=main_config, logger_config=logger_config, filter_config=filter_config, ssl_config=ssl_config, - flask_port=flask_port, - flask_pass=flask_pass, - html_403=html_403, - shortcuts=shortcuts, - custom_header=custom_header, - authorized_ips=authorized_ips, - proxy_enable=str_to_bool(proxy_enable), - proxy_host=proxy_host, - proxy_port=proxy_port, + monitoring_config=monitoring_config, + proxy_config=proxy_config, ) proxy.start() - - -if __name__ == "__main__": - main() diff --git a/pyproxy/server.py b/pyproxy/server.py index f25dcf2..53c64e7 100644 --- a/pyproxy/server.py +++ b/pyproxy/server.py @@ -12,6 +12,7 @@ import logging import multiprocessing import os +import ssl import time import ipaddress @@ -57,28 +58,19 @@ class ProxyServer: def __init__( self, - host, - port, - debug, + main_config, logger_config, filter_config, - html_403, ssl_config, - shortcuts, - custom_header, - flask_port, - flask_pass, - proxy_enable, - proxy_host, - proxy_port, - authorized_ips, + monitoring_config, + proxy_config, ): """ Initialize the ProxyServer with configuration parameters. """ - self.host_port = (host, port) - self.debug = debug - self.html_403 = html_403 + self.host_port = (main_config.host, main_config.port) + self.debug = main_config.debug + self.html_403 = main_config.html_403 self.active_connections = {} self.logger_config = logger_config @@ -86,16 +78,13 @@ def __init__( self.ssl_config = ssl_config # Monitoring - self.flask_port = flask_port - self.flask_pass = flask_pass + self.monitoring_config = monitoring_config # Proxy - self.proxy_enable = proxy_enable - self.proxy_host = proxy_host - self.proxy_port = proxy_port + self.proxy_config = proxy_config # Authorized IPS - self.authorized_ips = authorized_ips + self.authorized_ips = main_config.authorized_ips self.allowed_subnets = None # Process communication queues @@ -124,8 +113,8 @@ def __init__( ) # Configuration files - self.config_shortcuts = shortcuts - self.config_custom_header = custom_header + self.config_shortcuts = main_config.shortcuts + self.config_custom_header = main_config.custom_header def _initialize_processes(self): """ @@ -226,6 +215,46 @@ def _load_authorized_ips(self): ) self.allowed_subnets = None + def _validate_ssl_inspection_files(self): + """ + Validate SSL Inspection cert/key. + """ + required_files = [ + self.ssl_config.inspect_ca_cert, + self.ssl_config.inspect_ca_key, + ] + + for file_path in required_files: + if not os.path.exists(file_path): + raise FileNotFoundError(f"SSL file not found: {file_path}") + if not os.path.isfile(file_path): + raise ValueError(f"Invalid SSL file: {file_path} is not a file") + + try: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain( + certfile=self.ssl_config.inspect_ca_cert, + keyfile=self.ssl_config.inspect_ca_key, + ) + except ssl.SSLError as e: + raise ssl.SSLError(f"SSL certificate/key validation failed: {e}") + + def _start_monitoring_server(self): + """ + Start monitoring flask server. + """ + flask_thread = threading.Thread( + target=start_flask_server, + args=( + self, + self.monitoring_config.flask_port, + self.monitoring_config.flask_pass, + self.debug, + ), + daemon=True, + ) + flask_thread.start() + def start(self): """ Start the proxy server and listen for incoming client connections. @@ -240,19 +269,8 @@ def start(self): self.console_logger.debug("[*] %s = %s", key, getattr(self, key)) if self.ssl_config.ssl_inspect: - if not self.ssl_config.inspect_ca_cert or not os.path.isfile( - self.ssl_config.inspect_ca_cert - ): - raise FileNotFoundError( - f"CA certificate not found: {self.ssl_config.inspect_ca_cert}" - ) - if not self.ssl_config.inspect_ca_key or not os.path.isfile( - self.ssl_config.inspect_ca_key - ): - raise FileNotFoundError( - f"CA key not found: {self.ssl_config.inspect_ca_key}" - ) os.makedirs(self.ssl_config.inspect_certs_folder, exist_ok=True) + self._validate_ssl_inspection_files() self._clean_inspection_folder() if self.filter_config.filter_mode == "local": @@ -268,12 +286,7 @@ def start(self): self._load_authorized_ips() if not __slim__: - flask_thread = threading.Thread( - target=start_flask_server, - args=(self, self.flask_port, self.flask_pass, self.debug), - daemon=True, - ) - flask_thread.start() + self._start_monitoring_server() self.console_logger.debug("[*] Starting the monitoring process...") server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -292,7 +305,23 @@ def start(self): self.console_logger.debug( "Unauthorized IP blocked: %s", client_ip ) - client_socket.close() + with open(self.html_403, "r", encoding="utf-8") as f: + custom_403_page = f.read() + response = ( + "HTTP/1.1 403 Forbidden\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Connection: close\r\n\r\n" + f"{custom_403_page}" + ) + + try: + client_socket.sendall(response.encode("utf-8")) + except Exception as e: + self.console_logger.error( + "Error sending 403 response: %s", e + ) + finally: + client_socket.close() continue self.console_logger.debug("Connection from %s", addr) @@ -312,16 +341,13 @@ def start(self): console_logger=self.console_logger, shortcuts=self.config_shortcuts, custom_header=self.config_custom_header, - proxy_enable=self.proxy_enable, - proxy_host=self.proxy_host, - proxy_port=self.proxy_port, + proxy_config=self.proxy_config, active_connections=self.active_connections, ) client_handler = threading.Thread( target=client.handle_client, args=(client_socket,), daemon=True ) client_handler.start() - client_ip, client_port = addr self.active_connections[client_handler.ident] = { "client_ip": client_ip, "client_port": client_port, diff --git a/pyproxy/utils/config.py b/pyproxy/utils/config.py index 58ca952..11a97e7 100644 --- a/pyproxy/utils/config.py +++ b/pyproxy/utils/config.py @@ -4,121 +4,97 @@ This module defines configuration classes used by the HTTP/HTTPS proxy. """ +from dataclasses import dataclass, asdict + +@dataclass(frozen=True) +class ProxyConfigMain: + """ + Handles main configuration for the proxy. + """ + + host: str + port: int + debug: bool + html_403: str + shortcuts: str + custom_header: str + authorized_ips: str + + def to_dict(self): + return asdict(self) + + +@dataclass(frozen=True) +class ProxyConfigProxy: + """ + Handles proxy configuration for the proxy. + """ + + enable: bool + host: str + port: int + + def to_dict(self): + return asdict(self) + + +@dataclass(frozen=True) +class ProxyConfigMonitoring: + """ + Handles monitoring configuration for the proxy. + """ + + flask_port: int + flask_pass: str + + def to_dict(self): + return asdict(self) + + +@dataclass() class ProxyConfigLogger: """ Handles logging configuration for the proxy. """ - def __init__( - self, - access_log, - block_log, - no_logging_access, - no_logging_block, - console_format, - datefmt, - ): - self.access_log = access_log - self.block_log = block_log - self.access_logger = None - self.block_logger = None - self.no_logging_access = no_logging_access - self.no_logging_block = no_logging_block - self.console_format = console_format - self.datefmt = datefmt - - def __repr__(self): - return ( - f"ProxyConfigLogger(access_log={self.access_log}, " - f"block_log={self.block_log}, " - f"no_logging_access={self.no_logging_access}, " - f"no_logging_block={self.no_logging_block}), " - f"console_format={self.console_format}), " - f"datefmt={self.datefmt})" - ) + access_log: str + block_log: str + no_logging_access: bool + no_logging_block: bool + console_format: str + datefmt: str def to_dict(self): - """ - Converts the ProxyConfigLogger instance into a dictionary. - """ - return { - "access_log": self.access_log, - "block_log": self.block_log, - "no_logging_access": self.no_logging_access, - "no_logging_block": self.no_logging_block, - "console_format": self.console_format, - "datefmt": self.datefmt, - } + return asdict(self) +@dataclass(frozen=True) class ProxyConfigFilter: """ Manages filtering configuration for the proxy. """ - def __init__(self, no_filter, filter_mode, blocked_sites, blocked_url): - self.no_filter = no_filter - self.filter_mode = filter_mode - self.blocked_sites = blocked_sites - self.blocked_url = blocked_url - - def __repr__(self): - return ( - f"ProxyConfigFilter(no_filter={self.no_filter}, " - f"filter_mode='{self.filter_mode}', " - f"blocked_sites={self.blocked_sites}, " - f"blocked_url={self.blocked_url})" - ) + no_filter: bool + filter_mode: str + blocked_sites: str + blocked_url: str def to_dict(self): - """ - Converts the ProxyConfigFilter instance into a dictionary. - """ - return { - "no_filter": self.no_filter, - "filter_mode": self.filter_mode, - "blocked_sites": self.blocked_sites, - "blocked_url": self.blocked_url, - } + return asdict(self) +@dataclass(frozen=True) class ProxyConfigSSL: """ Handles SSL/TLS inspection configuration. """ - def __init__( - self, - ssl_inspect, - inspect_ca_cert, - inspect_ca_key, - inspect_certs_folder, - cancel_inspect, - ): - self.ssl_inspect = ssl_inspect - self.inspect_ca_cert = inspect_ca_cert - self.inspect_ca_key = inspect_ca_key - self.inspect_certs_folder = inspect_certs_folder - self.cancel_inspect = cancel_inspect - - def __repr__(self): - return ( - f"ProxyConfigSSL(ssl_inspect={self.ssl_inspect}, " - f"inspect_ca_cert='{self.inspect_ca_cert}', " - f"inspect_ca_key='{self.inspect_ca_key}', " - f"inspect_certs_folder='{self.inspect_certs_folder}', " - f"cancel_inspect={self.cancel_inspect})" - ) + ssl_inspect: bool + inspect_ca_cert: str + inspect_ca_key: str + inspect_certs_folder: str + cancel_inspect: str def to_dict(self): - """ - Converts the ProxyConfigSSL instance into a dictionary. - """ - return { - "ssl_inspect": self.ssl_inspect, - "inspect_ca_cert": self.inspect_ca_cert, - "inspect_ca_key": self.inspect_ca_key, - "inspect_certs_folder": self.inspect_certs_folder, - "cancel_inspect": self.cancel_inspect, - } + return asdict(self) diff --git a/pyproxy/utils/http_req.py b/pyproxy/utils/http_req.py index 4601763..c2d0d5f 100644 --- a/pyproxy/utils/http_req.py +++ b/pyproxy/utils/http_req.py @@ -22,32 +22,3 @@ def extract_headers(request_str): key, value = line.split(":", 1) headers[key.strip()] = value.strip() return headers - - -def parse_url(url): - """ - Parses the URL to extract the host and port for connecting to the target server. - - Args: - url (str): The URL to be parsed. - - Returns: - tuple: The server host and port. - """ - http_pos = url.find("//") - if http_pos != -1: - url = url[(http_pos + 2) :] - port_pos = url.find(":") - path_pos = url.find("/") - if path_pos == -1: - path_pos = len(url) - - server_host = ( - url[:path_pos] if port_pos == -1 or port_pos > path_pos else url[:port_pos] - ) - if port_pos == -1 or port_pos > path_pos: - server_port = 80 - else: - server_port = int(url[(port_pos + 1) : path_pos]) - - return server_host, server_port diff --git a/pyproxy/utils/logger.py b/pyproxy/utils/logger.py index 7942fa0..d2d425d 100644 --- a/pyproxy/utils/logger.py +++ b/pyproxy/utils/logger.py @@ -16,7 +16,6 @@ def configure_console_logger(logger_config) -> logging.Logger: Returns: logging.Logger: A logger instance that writes logs to the console. """ - print(logger_config) console_logger = logging.getLogger("ConsoleLogger") console_logger.setLevel(logging.INFO) diff --git a/pyproxy/utils/version.py b/pyproxy/utils/version.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/utils/test_http_req.py b/tests/utils/test_http_req.py index b9ef2fc..5f844e5 100644 --- a/tests/utils/test_http_req.py +++ b/tests/utils/test_http_req.py @@ -5,7 +5,7 @@ """ import unittest -from pyproxy.utils.http_req import extract_headers, parse_url +from pyproxy.utils.http_req import extract_headers class TestHttpReq(unittest.TestCase): @@ -34,36 +34,6 @@ def test_extract_headers(self): headers = extract_headers(request_str) self.assertEqual(headers, expected_headers) - def test_parse_url(self): - """ - Test the `parse_url` function to ensure it correctly extracts the host and port - from a URL. - """ - - url = "http://example.com:8080/path/to/resource" - expected_host, expected_port = "example.com", 8080 - host, port = parse_url(url) - self.assertEqual(host, expected_host) - self.assertEqual(port, expected_port) - - url = "http://example.com/path/to/resource" - expected_host, expected_port = "example.com", 80 - host, port = parse_url(url) - self.assertEqual(host, expected_host) - self.assertEqual(port, expected_port) - - url = "example.com:9090" - expected_host, expected_port = "example.com", 9090 - host, port = parse_url(url) - self.assertEqual(host, expected_host) - self.assertEqual(port, expected_port) - - url = "example.com" - expected_host, expected_port = "example.com", 80 - host, port = parse_url(url) - self.assertEqual(host, expected_host) - self.assertEqual(port, expected_port) - if __name__ == "__main__": unittest.main()