diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 54fd893..03cf04a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -26,4 +26,4 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Spec
-        run: crystal spec
+        run: crystal spec -Dpreview_mt -Dexecution_context
diff --git a/README.md b/README.md
index 3f04ed2..6a68d62 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,12 @@
 # SystemD
 
-SystemD integration for Crystal applications, can notify systemd, get socket listeners and store/restore file descriptors. libsystemd is only required for storing FDs.
+SystemD integration for Crystal applications, can notify systemd, get socket listeners, store/restore file descriptors, and monitor memory pressure. libsystemd is only required for storing FDs.
 
 Man pages:
 
 https://man7.org/linux/man-pages/man3/sd_pid_notify.3.html
 https://man7.org/linux/man-pages/man3/sd_listen_fds.3.html
+https://systemd.io/MEMORY_PRESSURE/
 
 ## Installation
 
@@ -41,6 +42,14 @@ end
 # Enable systemd watchdog support with `WatchdogSec=5` under `[Service]`
 SystemD.watchdog
 
+# Monitor memory pressure notifications from systemd
+# Enable with `MemoryPressureWatch=auto` and `MemoryPressureThresholdSec=1s` under `[Service]`
+SystemD::MemoryPressure.monitor do
+  # Called when memory pressure is detected
+  # Take action like clearing caches, reducing memory usage, etc.
+  clear_caches
+end
+
 # Store FDs with the SystemD, they will be sent back
 # to the application when it restarts. Requires libsystemd
 clients = Array(TCPSocket).new
diff --git a/spec/memory_pressure_spec.cr b/spec/memory_pressure_spec.cr
new file mode 100644
index 0000000..edca2a7
--- /dev/null
+++ b/spec/memory_pressure_spec.cr
@@ -0,0 +1,155 @@
+require "./spec_helper"
+require "../src/memory_pressure"
+
+describe SystemD::MemoryPressure do
+  it "does nothing when MEMORY_PRESSURE_WATCH is not set" do
+    ENV.delete("MEMORY_PRESSURE_WATCH")
+    ENV.delete("MEMORY_PRESSURE_WRITE")
+
+    called = false
+    SystemD::MemoryPressure.monitor { called = true }
+
+    sleep 0.1.seconds
+    called.should be_false
+  end
+
+  it "does nothing when MEMORY_PRESSURE_WATCH is /dev/null" do
+    ENV["MEMORY_PRESSURE_WATCH"] = "/dev/null"
+    ENV.delete("MEMORY_PRESSURE_WRITE")
+
+    called = false
+    SystemD::MemoryPressure.monitor { called = true }
+
+    sleep 0.1.seconds
+    called.should be_false
+  end
+
+  it "monitors memory pressure on a FIFO" do
+    fifo_path = File.tempname
+    begin
+      # Create a FIFO
+      ret = LibC.mkfifo(fifo_path, 0o600)
+      raise IO::Error.from_errno("mkfifo failed") if ret != 0
+
+      ENV["MEMORY_PRESSURE_WATCH"] = fifo_path
+      ENV.delete("MEMORY_PRESSURE_WRITE")
+
+      wg = WaitGroup.new(1)
+      SystemD::MemoryPressure.monitor { wg.done }
+
+      # Write to the FIFO to trigger memory pressure
+      File.open(fifo_path, "w") do |f|
+        f.sync = true
+        f.print "pressure"
+      end
+
+      # Wait for the callback to be called
+      wg.wait
+    ensure
+      File.delete(fifo_path) if File.exists?(fifo_path)
+      ENV.delete("MEMORY_PRESSURE_WATCH")
+    end
+  end
+
+  it "monitors memory pressure on a Unix socket" do
+    socket_path = File.tempname
+    begin
+      # Create a Unix socket server
+      server = UNIXServer.new(socket_path)
+
+      ENV["MEMORY_PRESSURE_WATCH"] = socket_path
+      ENV.delete("MEMORY_PRESSURE_WRITE")
+
+      ch = Channel(Nil).new
+      SystemD::MemoryPressure.monitor { ch.send(nil) }
+
+      # Accept the connection and send data
+      client = server.accept
+      client.print "pressure"
+      client.flush
+
+      # Wait for the callback to be called
+      ch.receive
+
+      client.close
+      server.close
+    ensure
+      File.delete(socket_path) if File.exists?(socket_path)
+      ENV.delete("MEMORY_PRESSURE_WATCH")
+    end
+  end
+
+  it "writes threshold data when MEMORY_PRESSURE_WRITE is set" do
+    socket_path = File.tempname
+    begin
+      server = UNIXServer.new(socket_path)
+
+      # Set up both environment variables
+      write_data = "some threshold"
+      ENV["MEMORY_PRESSURE_WATCH"] = socket_path
+      ENV["MEMORY_PRESSURE_WRITE"] = Base64.strict_encode(write_data)
+
+      ch = Channel(Nil).new
+      SystemD::MemoryPressure.monitor { ch.send(nil) }
+
+      # Accept the connection and read the threshold data
+      client = server.accept
+      buffer = uninitialized UInt8[4096]
+      count = client.read(buffer.to_slice)
+      received = String.new(buffer.to_unsafe, count)
+      received.should eq write_data
+
+      # Now send pressure notification
+      client.print "pressure"
+      client.flush
+
+      # Wait for the callback to be called
+      ch.receive
+
+      client.close
+      server.close
+    ensure
+      File.delete(socket_path) if File.exists?(socket_path)
+      ENV.delete("MEMORY_PRESSURE_WATCH")
+      ENV.delete("MEMORY_PRESSURE_WRITE")
+    end
+  end
+
+  it "handles socket reconnection" do
+    socket_path = File.tempname
+    begin
+      server = UNIXServer.new(socket_path)
+
+      ENV["MEMORY_PRESSURE_WATCH"] = socket_path
+      ENV.delete("MEMORY_PRESSURE_WRITE")
+
+      call_count = 0
+      SystemD::MemoryPressure.monitor { call_count += 1 }
+
+      # Accept first connection and trigger pressure
+      client1 = server.accept
+      client1.print "pressure1"
+
+      # Close the connection to force reconnection
+      client1.close
+
+      # Accept second connection and trigger pressure again
+      client2 = server.accept
+      client2.print "pressure2"
+
+      # Wait a bit for callbacks
+      timeout = Time.monotonic + 2.seconds
+      until call_count >= 2 || Time.monotonic > timeout
+        Fiber.yield
+      end
+
+      call_count.should be >= 2
+
+      client2.close
+      server.close
+    ensure
+      File.delete(socket_path) if File.exists?(socket_path)
+      ENV.delete("MEMORY_PRESSURE_WATCH")
+    end
+  end
+end
diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr
index 8fa8e76..4dae349 100644
--- a/spec/spec_helper.cr
+++ b/spec/spec_helper.cr
@@ -1,5 +1,6 @@
 require "spec"
 require "socket"
+require "wait_group"
 require "../src/systemd"
 
 ENV["LISTEN_PID"] = Process.pid.to_s
diff --git a/src/memory_pressure.cr b/src/memory_pressure.cr
new file mode 100644
index 0000000..32a83de
--- /dev/null
+++ b/src/memory_pressure.cr
@@ -0,0 +1,245 @@
+require "base64"
+require "log"
+
+lib LibC
+  struct PollFD
+    fd : Int
+    events : Short
+    revents : Short
+  end
+
+  # Event bits from Linux <poll.h>
+  POLLIN   = 0x0001
+  POLLPRI  = 0x0002
+  POLLERR  = 0x0008
+  POLLHUP  = 0x0010
+  POLLNVAL = 0x0020
+
+  # nfds_t is `unsigned long` on Linux
+  fun poll(fds : PollFD*, nfds : ULong, timeout : Int) : Int
+end
+
+module SystemD
+  # Monitors memory pressure using systemd's memory pressure notification
+  # mechanism, see https://systemd.io/MEMORY_PRESSURE/
+  module MemoryPressure
+    Log = ::Log.for("systemd.memory_pressure")
+
+    # Conditions poll(2) reports in `revents` even when they were not requested
+    POLL_FAILURE = LibC::POLLERR | LibC::POLLHUP | LibC::POLLNVAL
+
+    # Starts watching the path in `MEMORY_PRESSURE_WATCH` and calls the block
+    # each time memory pressure is detected. Does nothing when the variable is
+    # unset or set to `/dev/null`. The watcher blocks in poll(2), so it runs in
+    # its own `Fiber::ExecutionContext::Isolated`; an exception raised there is
+    # logged and ends the monitoring.
+    def self.monitor(&block : ->)
+      watch_path = ENV["MEMORY_PRESSURE_WATCH"]?
+      return unless watch_path
+
+      if watch_path == "/dev/null"
+        Log.info { "Memory pressure monitoring disabled" }
+        return
+      end
+
+      Fiber::ExecutionContext::Isolated.new("Memory Pressure Monitor") do
+        begin
+          monitor_internal(watch_path, &block)
+        rescue ex
+          Log.error(exception: ex) { "Memory pressure monitoring failed" }
+        end
+      end
+    end
+
+    # Picks the watch strategy from the file type of *watch_path*
+    private def self.monitor_internal(watch_path : String, &block : ->)
+      write_data = decode_write_data
+      file_type = determine_file_type(watch_path)
+
+      case file_type
+      when :regular
+        monitor_regular_file(watch_path, write_data, &block)
+      when :fifo
+        monitor_fifo(watch_path, write_data, &block)
+      when :socket
+        monitor_socket(watch_path, write_data, &block)
+      else
+        Log.warn { "Unknown file type for #{watch_path}, attempting as regular file" }
+        monitor_regular_file(watch_path, write_data, &block)
+      end
+    end
+
+    # Base64-decoded content of `MEMORY_PRESSURE_WRITE`, nil when it is unset
+    private def self.decode_write_data : Bytes?
+      if encoded = ENV["MEMORY_PRESSURE_WRITE"]?
+        Base64.decode(encoded)
+      end
+    end
+
+    # Classifies *path* via stat(2); raises IO::Error when stat fails
+    private def self.determine_file_type(path : String) : Symbol
+      result = LibC.stat(path, out stat)
+      raise IO::Error.from_errno("stat failed") if result != 0
+
+      file_mode = stat.st_mode & LibC::S_IFMT
+      case file_mode
+      when LibC::S_IFREG
+        :regular
+      when LibC::S_IFIFO
+        :fifo
+      when LibC::S_IFSOCK
+        :socket
+      else
+        :unknown
+      end
+    end
+
+    # Writes the pressure threshold data to *fd*, if there is any
+    private def self.write_threshold(fd : Int32, write_data : Bytes?, error_message : String) : Nil
+      return unless write_data
+
+      written = LibC.write(fd, write_data, write_data.size)
+      raise IO::Error.from_errno(error_message) if written < 0
+    end
+
+    # Blocks in poll(2) until *fd* reports an event and returns its `revents`
+    # bits. Calls interrupted by a signal (EINTR) are retried.
+    private def self.wait_for_events(fd : Int32, events : Int32) : Int32
+      poll_fd = LibC::PollFD.new
+      poll_fd.fd = fd
+      poll_fd.events = events.to_i16
+
+      loop do
+        result = LibC.poll(pointerof(poll_fd), 1, -1) # -1 = infinite timeout
+        return poll_fd.revents.to_i32 if result > 0
+        next if result == 0 || Errno.value == Errno::EINTR
+        raise IO::Error.from_errno("poll failed")
+      end
+    end
+
+    # Watches a regular file: writes the threshold, then waits for POLLPRI
+    private def self.monitor_regular_file(path : String, write_data : Bytes?, &block : ->)
+      Log.info { "Monitoring memory pressure on regular file: #{path}" }
+
+      fd = LibC.open(path, LibC::O_RDWR)
+      raise IO::Error.from_errno("open failed") if fd < 0
+
+      begin
+        write_threshold(fd, write_data, "write failed")
+
+        loop do
+          revents = wait_for_events(fd, LibC::POLLPRI)
+          if (revents & POLL_FAILURE) != 0
+            # poll(2) keeps returning immediately while an error condition is
+            # set, so waiting on would busy-loop
+            raise IO::Error.new("Memory pressure event source is gone: #{path}")
+          elsif (revents & LibC::POLLPRI) != 0
+            # For regular files, we don't read from the FD
+            handle_memory_pressure(&block)
+          end
+        end
+      ensure
+        LibC.close(fd)
+      end
+    end
+
+    # Watches a FIFO: every readable event is a notification, data is dropped
+    private def self.monitor_fifo(path : String, write_data : Bytes?, &block : ->)
+      Log.info { "Monitoring memory pressure on FIFO: #{path}" }
+
+      fd = LibC.open(path, LibC::O_RDWR)
+      raise IO::Error.from_errno("open failed") if fd < 0
+
+      begin
+        write_threshold(fd, write_data, "write failed")
+
+        scratch = uninitialized UInt8[4096]
+        loop do
+          revents = wait_for_events(fd, LibC::POLLIN)
+          if (revents & LibC::POLLIN) != 0
+            handle_memory_pressure(&block)
+            # Read and discard data from FIFO; EOF (0) or error is expected,
+            # continue polling
+            LibC.read(fd, scratch, scratch.size)
+          elsif (revents & POLL_FAILURE) != 0
+            # Nothing to read and an error condition set: poll(2) would
+            # return immediately from now on, so stop instead of spinning
+            raise IO::Error.new("Memory pressure FIFO failed: #{path}")
+          end
+        end
+      ensure
+        LibC.close(fd)
+      end
+    end
+
+    # Watches a Unix socket: received data is a notification, EOF or a socket
+    # error triggers a reconnect (the threshold data is written again)
+    private def self.monitor_socket(path : String, write_data : Bytes?, &block : ->)
+      Log.info { "Monitoring memory pressure on Unix socket: #{path}" }
+
+      fd = connect_unix_socket(path)
+
+      begin
+        write_threshold(fd, write_data, "write failed")
+
+        buffer = uninitialized UInt8[4096]
+        loop do
+          revents = wait_for_events(fd, LibC::POLLIN)
+          connected = true
+          if (revents & LibC::POLLIN) != 0
+            # Read and discard data from socket
+            bytes_read = LibC.read(fd, buffer, buffer.size)
+            if bytes_read > 0
+              # Only received data is a notification; EOF is a disconnect
+              handle_memory_pressure(&block)
+            elsif bytes_read == 0 || Errno.value != Errno::EINTR
+              connected = false
+            end
+          elsif (revents & POLL_FAILURE) != 0
+            connected = false
+          end
+          next if connected
+
+          # Connection closed, reconnect
+          LibC.close(fd)
+          fd = -1 # keeps `ensure` from closing a stale FD if reconnecting raises
+          fd = connect_unix_socket(path)
+          write_threshold(fd, write_data, "write failed after reconnect")
+        end
+      ensure
+        LibC.close(fd) if fd >= 0
+      end
+    end
+
+    # Connects a stream socket to the Unix socket at *path* and returns the FD
+    private def self.connect_unix_socket(path : String) : Int32
+      sockaddr = Pointer(LibC::SockaddrUn).malloc
+      capacity = sockaddr.value.sun_path.size
+      # sun_path must also hold the trailing NUL byte; a silently truncated
+      # path would name a different socket
+      if path.bytesize >= capacity
+        raise ArgumentError.new("Unix socket path too long (max #{capacity - 1} bytes): #{path}")
+      end
+      sockaddr.value.sun_family = LibC::AF_UNIX.to_u16
+      sockaddr.value.sun_path.to_unsafe.copy_from(path.to_unsafe, path.bytesize + 1)
+
+      fd = LibC.socket(LibC::AF_UNIX, LibC::SOCK_STREAM, 0)
+      raise IO::Error.from_errno("socket creation failed") if fd < 0
+
+      if LibC.connect(fd, sockaddr.as(LibC::Sockaddr*), sizeof(LibC::SockaddrUn)) < 0
+        # Build the error first, close(2) may overwrite errno
+        error = IO::Error.from_errno("connect failed")
+        LibC.close(fd)
+        raise error
+      end
+
+      fd
+    end
+
+    # Logs the event and runs the user supplied callback
+    private def self.handle_memory_pressure(&block : ->)
+      Log.info { "Memory pressure detected" }
+      block.call
+    end
+  end
+end
diff --git a/src/systemd.cr b/src/systemd.cr
index 79a79aa..75cb8d0 100644
--- a/src/systemd.cr
+++ b/src/systemd.cr
@@ -2,6 +2,7 @@ require "socket"
 {% if flag?(:linux) %}
   require "./libsystemd"
 {% end %}
+require "./memory_pressure"
 
 # Wrapper for libsystemd
 # http://man7.org/linux/man-pages/man3/sd_pid_notify_with_fds.3.html