diff --git a/src/sandlock/_context.py b/src/sandlock/_context.py
index 22cf8bd..4ee8d12 100644
--- a/src/sandlock/_context.py
+++ b/src/sandlock/_context.py
@@ -125,15 +125,15 @@ def _notif_syscall_names(notif: "NotifPolicy") -> list[str]:
     from ._seccomp import _SYSCALL_NR
 
     names = []
-    # openat/open needed when features require path inspection.
+    # openat/open only needed when features require path inspection
     needs_openat = (
         notif is not None and (
             notif.rules  # path-based virtualization (e.g. /etc/hosts)
-            or notif.isolate_pids  # block /proc/ access
             or notif.cow_enabled  # COW filesystem redirects
             or notif.random_seed is not None  # deterministic /dev/urandom
             or notif.time_start is not None  # /proc/uptime, /proc/stat virtualization
             or notif.port_remap  # /proc/net/* filtering
+            or notif.isolate_pids  # /proc/ access control
         )
     )
     if needs_openat:
diff --git a/src/sandlock/_notif.py b/src/sandlock/_notif.py
index 129f5b7..a782a2b 100644
--- a/src/sandlock/_notif.py
+++ b/src/sandlock/_notif.py
@@ -528,6 +528,9 @@ def stop(self) -> None:
                 pass
         self._mem_fd = -1
         self._mem_fd_pid = -1
+        if self._port_map is not None:
+            self._port_map.close()
+            self._port_map = None
 
     def _check_disk_quota(self) -> None:
         """Check if overlay upper dir exceeds disk quota."""
diff --git a/src/sandlock/_port_remap.py b/src/sandlock/_port_remap.py
index 6119e00..f536393 100644
--- a/src/sandlock/_port_remap.py
+++ b/src/sandlock/_port_remap.py
@@ -47,8 +47,6 @@ class PortMap:
     _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
     _virtual_to_real: dict[int, int] = field(default_factory=dict, repr=False)
     _real_to_virtual: dict[int, int] = field(default_factory=dict, repr=False)
-    # Sockets held open to keep real ports reserved
-    _held_sockets: list[socket.socket] = field(default_factory=list, repr=False)
     # Proxy state
     _proxy_threads: list[threading.Thread] = field(default_factory=list, repr=False)
     _proxy_sockets: list[socket.socket] = field(default_factory=list, repr=False)
@@ -96,7 +94,7 @@ def virtual_port(self, real: int) -> int | None:
             return self._real_to_virtual.get(real)
 
     def close(self) -> None:
-        """Release all held sockets, stop proxies."""
+        """Stop proxies and release all state."""
         self._proxy_stop.set()
         for s in self._proxy_sockets:
             try:
@@ -106,12 +104,6 @@ def close(self) -> None:
         for t in self._proxy_threads:
             t.join(timeout=2.0)
         with self._lock:
-            for s in self._held_sockets:
-                try:
-                    s.close()
-                except OSError:
-                    pass
-            self._held_sockets.clear()
             self._virtual_to_real.clear()
             self._real_to_virtual.clear()
             self._proxy_sockets.clear()
@@ -138,17 +130,16 @@ def _try_reserve_port(self, port: int, family: int) -> int | None:
 
     def _allocate_real_port(self, family: int) -> int | None:
         """Bind a socket to port 0 to get a free port from the kernel."""
+        af = socket.AF_INET6 if family == _AF_INET6 else socket.AF_INET
+        addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
+        # Create the socket inside the try so that a failing socket() call
+        # (e.g. address family unavailable) still yields None, as before.
         try:
-            af = socket.AF_INET6 if family == _AF_INET6 else socket.AF_INET
-            s = socket.socket(af, socket.SOCK_STREAM)
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
-            s.bind((addr, 0))
-            real_port = s.getsockname()[1]
-            self._held_sockets.append(s)
-            return real_port
+            with socket.socket(af, socket.SOCK_STREAM) as s:
+                s.bind((addr, 0))
+                return s.getsockname()[1]
         except OSError:
             return None
 
     def _start_proxy(self, virtual: int, real: int, family: int) -> None:
         """Start a TCP proxy: listen on virtual port, forward to real port."""
@@ -363,8 +354,9 @@ def fixup_getsockname(pid: int, sockaddr_addr: int, addrlen_addr: int,
     from ._procfs import write_bytes
 
     # Duplicate the child's socket fd via pidfd_getfd syscall
+    from ._context import _pidfd_open
     try:
-        pidfd = os.pidfd_open(pid)
+        pidfd = _pidfd_open(pid)
     except OSError:
         return False
 
@@ -387,8 +379,8 @@ def fixup_getsockname(pid: int, sockaddr_addr: int, addrlen_addr: int,
             family = s.family
         finally:
             s.detach()
+            os.close(local_fd)
     except OSError:
-        os.close(local_fd)
         return False
 
     if family not in (socket.AF_INET, socket.AF_INET6):
diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py
index f288c78..5b1e4ce 100644
--- a/tests/test_sandbox.py
+++ b/tests/test_sandbox.py
@@ -2,6 +2,7 @@
 """Tests for sandlock.sandbox.Sandbox."""
 
 import os
+import socket
 import sys
 from unittest.mock import patch, MagicMock
 
@@ -322,7 +323,7 @@ def test_proc_net_tcp_shows_own_port_only(self):
 
         result = Sandbox(policy).run(["python3", "-c", code])
         assert result.success
-        assert int(result.stdout.strip()) >= 1
+        assert int(result.stdout.strip()) == 1
 
     def test_proc_net_tcp_hides_host_ports(self):
         """/proc/net/tcp hides host ports (e.g. sshd on port 22)."""
@@ -372,7 +373,7 @@ def test_proc_net_tcp6_filtered(self):
 
         result = Sandbox(policy).run(["python3", "-c", code])
         assert result.success
-        assert int(result.stdout.strip()) >= 1
+        assert int(result.stdout.strip()) == 1
 
     def test_tcp_sendmsg_2mb_with_port_remap(self):
         """TCP sendmsg() with 2 MB payload works correctly under port remap."""
@@ -443,6 +444,70 @@ def test_tcp_sendmsg_2mb_with_port_remap(self):
 
         assert data["data_ok"] is True
 
+    def test_slow_path_host_holds_virtual_port(self):
+        """Slow path: host process holds TCP virtual port, sandbox must remap."""
+        code = (
+            "import socket; "
+            "s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
+            "s.bind(('127.0.0.1', 8080)); "
+            "print(s.getsockname()[1]); "
+            "s.close()"
+        )
+        policy = Policy(port_remap=True)
+
+        # Hold the virtual port on the host for the whole sandbox run; the
+        # with-block releases it even if setsockopt/bind/run raises.
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as holder:
+            holder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            holder.bind(("127.0.0.1", 8080))
+            result = Sandbox(policy).run(["python3", "-c", code])
+
+        assert result.success, f"Sandbox failed: {result.stderr}"
+        assert result.stdout.strip() == b"8080"
+
+    def test_slow_path_two_concurrent_sandboxes(self):
+        """Slow path: two concurrent sandboxes both bind the same virtual TCP port."""
+        import threading
+        import time
+        code_hold = (
+            "import socket, time; "
+            "s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
+            "s.bind(('127.0.0.1', 8080)); "
+            "print(s.getsockname()[1], flush=True); "
+            "time.sleep(3); "
+            "s.close()"
+        )
+        code_fast = (
+            "import socket; "
+            "s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
+            "s.bind(('127.0.0.1', 8080)); "
+            "print(s.getsockname()[1]); "
+            "s.close()"
+        )
+        policy = Policy(port_remap=True)
+        results = [None, None]
+
+        def run(i, code):
+            results[i] = Sandbox(policy).run(["python3", "-c", code])
+
+        t1 = threading.Thread(target=run, args=(0, code_hold))
+        t1.start()
+        for _ in range(50):  # wait for sandbox 1 to bind port 8080
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
+                try:
+                    probe.bind(("127.0.0.1", 8080))
+                except OSError:
+                    break  # sandbox 1 is ready
+            time.sleep(0.1)
+        t2 = threading.Thread(target=run, args=(1, code_fast))
+        t2.start()
+        t1.join()
+        t2.join()
+
+        r1, r2 = results
+        assert r1.success, f"Sandbox 1 failed: {r1.stderr}"
+        assert r2.success, f"Sandbox 2 failed: {r2.stderr}"
+        assert r1.stdout.strip() == b"8080"
+        assert r2.stdout.strip() == b"8080"
+
 
 class TestCpuThrottle:
     """Test SIGSTOP/SIGCONT CPU throttling."""