Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sandlock/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ def _notif_syscall_names(notif: "NotifPolicy") -> list[str]:
from ._seccomp import _SYSCALL_NR
names = []

# openat/open needed when features require path inspection.
# openat/open only needed when features require path inspection
needs_openat = (
notif is not None and (
notif.rules # path-based virtualization (e.g. /etc/hosts)
or notif.isolate_pids # block /proc/<foreign_pid> access
or notif.cow_enabled # COW filesystem redirects
or notif.random_seed is not None # deterministic /dev/urandom
or notif.time_start is not None # /proc/uptime, /proc/stat virtualization
or notif.port_remap # /proc/net/* filtering
or notif.isolate_pids # /proc/<pid> access control
)
)
if needs_openat:
Expand Down
3 changes: 3 additions & 0 deletions src/sandlock/_notif.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ def stop(self) -> None:
pass
self._mem_fd = -1
self._mem_fd_pid = -1
if self._port_map is not None:
self._port_map.close()
self._port_map = None

def _check_disk_quota(self) -> None:
"""Check if overlay upper dir exceeds disk quota."""
Expand Down
28 changes: 10 additions & 18 deletions src/sandlock/_port_remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class PortMap:
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
_virtual_to_real: dict[int, int] = field(default_factory=dict, repr=False)
_real_to_virtual: dict[int, int] = field(default_factory=dict, repr=False)
# Sockets held open to keep real ports reserved
_held_sockets: list[socket.socket] = field(default_factory=list, repr=False)
# Proxy state
_proxy_threads: list[threading.Thread] = field(default_factory=list, repr=False)
_proxy_sockets: list[socket.socket] = field(default_factory=list, repr=False)
Expand Down Expand Up @@ -96,7 +94,7 @@ def virtual_port(self, real: int) -> int | None:
return self._real_to_virtual.get(real)

def close(self) -> None:
"""Release all held sockets, stop proxies."""
"""Stop proxies and release all state."""
self._proxy_stop.set()
for s in self._proxy_sockets:
try:
Expand All @@ -106,12 +104,6 @@ def close(self) -> None:
for t in self._proxy_threads:
t.join(timeout=2.0)
with self._lock:
for s in self._held_sockets:
try:
s.close()
except OSError:
pass
self._held_sockets.clear()
self._virtual_to_real.clear()
self._real_to_virtual.clear()
self._proxy_sockets.clear()
Expand All @@ -138,17 +130,16 @@ def _try_reserve_port(self, port: int, family: int) -> int | None:

def _allocate_real_port(self, family: int) -> int | None:
"""Bind a socket to port 0 to get a free port from the kernel."""
af = socket.AF_INET6 if family == _AF_INET6 else socket.AF_INET
addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
s = socket.socket(af, socket.SOCK_STREAM)
try:
af = socket.AF_INET6 if family == _AF_INET6 else socket.AF_INET
s = socket.socket(af, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
s.bind((addr, 0))
real_port = s.getsockname()[1]
self._held_sockets.append(s)
return real_port
return s.getsockname()[1]
except OSError:
return None
finally:
s.close()

def _start_proxy(self, virtual: int, real: int, family: int) -> None:
"""Start a TCP proxy: listen on virtual port, forward to real port."""
Expand Down Expand Up @@ -363,8 +354,9 @@ def fixup_getsockname(pid: int, sockaddr_addr: int, addrlen_addr: int,
from ._procfs import write_bytes

# Duplicate the child's socket fd via pidfd_getfd syscall
from ._context import _pidfd_open
try:
pidfd = os.pidfd_open(pid)
pidfd = _pidfd_open(pid)
except OSError:
return False

Expand All @@ -387,8 +379,8 @@ def fixup_getsockname(pid: int, sockaddr_addr: int, addrlen_addr: int,
family = s.family
finally:
s.detach()
os.close(local_fd)
except OSError:
os.close(local_fd)
return False

if family not in (socket.AF_INET, socket.AF_INET6):
Expand Down
74 changes: 72 additions & 2 deletions tests/test_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Tests for sandlock.sandbox.Sandbox."""

import os
import socket
import sys
from unittest.mock import patch, MagicMock

Expand Down Expand Up @@ -322,7 +323,7 @@ def test_proc_net_tcp_shows_own_port_only(self):
result = Sandbox(policy).run(["python3", "-c", code])

assert result.success
assert int(result.stdout.strip()) >= 1
assert int(result.stdout.strip()) == 1

def test_proc_net_tcp_hides_host_ports(self):
"""/proc/net/tcp hides host ports (e.g. sshd on port 22)."""
Expand Down Expand Up @@ -372,7 +373,7 @@ def test_proc_net_tcp6_filtered(self):
result = Sandbox(policy).run(["python3", "-c", code])

assert result.success
assert int(result.stdout.strip()) >= 1
assert int(result.stdout.strip()) == 1

def test_tcp_sendmsg_2mb_with_port_remap(self):
"""TCP sendmsg() with 2 MB payload works correctly under port remap."""
Expand Down Expand Up @@ -443,6 +444,75 @@ def test_tcp_sendmsg_2mb_with_port_remap(self):
assert data["data_ok"] is True


def test_slow_path_host_holds_virtual_port(self):
"""Slow path: host process holds TCP virtual port, sandbox must remap."""
import socket as _socket
code = (
"import socket; "
"s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
"s.bind(('127.0.0.1', 8080)); "
"print(s.getsockname()[1]); "
"s.close()"
)
policy = Policy(port_remap=True)

holder = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
holder.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1)
holder.bind(("127.0.0.1", 8080))
try:
result = Sandbox(policy).run(["python3", "-c", code])
finally:
holder.close()

assert result.success, f"Sandbox failed: {result.stderr}"
assert result.stdout.strip() == b"8080"

def test_slow_path_two_concurrent_sandboxes(self):
"""Slow path: two concurrent sandboxes both bind the same virtual TCP port."""
import threading
code_hold = (
"import socket, time; "
"s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
"s.bind(('127.0.0.1', 8080)); "
"print(s.getsockname()[1], flush=True); "
"time.sleep(3); "
"s.close()"
)
code_fast = (
"import socket; "
"s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); "
"s.bind(('127.0.0.1', 8080)); "
"print(s.getsockname()[1]); "
"s.close()"
)
policy = Policy(port_remap=True)
results = [None, None]

def run(i, code):
results[i] = Sandbox(policy).run(["python3", "-c", code])

t1 = threading.Thread(target=run, args=(0, code_hold))
t1.start()
import time
for _ in range(50): # wait for sandbox 1 to bind port 8080
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
probe.bind(("127.0.0.1", 8080))
probe.close()
time.sleep(0.1)
except OSError:
break # sandbox 1 is ready
t2 = threading.Thread(target=run, args=(1, code_fast))
t2.start()
t1.join()
t2.join()

r1, r2 = results
assert r1.success, f"Sandbox 1 failed: {r1.stderr}"
assert r2.success, f"Sandbox 2 failed: {r2.stderr}"
assert r1.stdout.strip() == b"8080"
assert r2.stdout.strip() == b"8080"

class TestCpuThrottle:
"""Test SIGSTOP/SIGCONT CPU throttling."""

Expand Down