diff --git a/dataloader.zig b/dataloader.zig index 38a475b..e71dc1f 100644 --- a/dataloader.zig +++ b/dataloader.zig @@ -8,14 +8,14 @@ const logger = std.log.scoped(.dataloader); const wlog = std.log.scoped(.dataloader_io_thread); pub const FileHandle = packed struct { - _: u14 = 0, idx: u20, generation: u20, path_checksum: u8, + _: u16 = 0, }; -const max_file_slots = std.math.maxInt(@FieldType(FileHandle, "idx")); -const max_generation = std.math.maxInt(@FieldType(FileHandle, "generation")); +const max_file_slots = std.math.maxInt(u20); +const max_generation = std.math.maxInt(u20); pub const ReadBlockReq = struct { base: u64, @@ -88,6 +88,7 @@ pub const LoaderCtx = struct { loop: xev.Loop, worker_thread: ?std.Thread = null, + worker_done: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), req_cnt: u64 = 0, is_running: bool = false, @@ -104,14 +105,10 @@ pub const LoaderCtx = struct { // This isn't called too often, just linear scan for (0..max_file_slots) |offset| { const i = (self.last_slot + offset) % max_file_slots; - const f = self.file_slots[i]; + const f = self.file_slots[0..][i]; if (f == null) { self.last_slot = (i + 1) % max_file_slots; - return .{ - .idx = @intCast(i), - .generation = 0, - .path_checksum = 0, - }; + return .{ .idx = @intCast(i), .generation = 0, .path_checksum = 0 }; } } @@ -119,10 +116,15 @@ pub const LoaderCtx = struct { } fn checkFilehandle(self: *Self, file: FileHandle) !void { - const slot: usize = @intCast(file.idx); - if (self.file_slots[slot] == null) return LoaderError.InvalidFileHandle; - if (self.file_slots_generation[slot] != file.generation or self.file_slots_checksum[slot] != file.path_checksum) { - logger.warn("File handle {} is corrupted, current generation: {}, checksum: {}", .{ file, self.file_slots_generation[slot], self.file_slots_checksum[slot] }); + const slot: usize = file.idx; + const file_slots = self.file_slots[0..]; + const generations = self.file_slots_generation[0..]; + const checksums = self.file_slots_checksum[0..]; + + if (slot >= file_slots.len) return LoaderError.InvalidFileHandle; + if (file_slots[slot] == null) return LoaderError.InvalidFileHandle; + if (generations[slot] != file.generation or checksums[slot] != file.path_checksum) { + logger.warn("File handle {} is corrupted, current generation: {}, checksum: {}", .{ file, generations[slot], checksums[slot] }); return LoaderError.InvalidFileHandle; } } @@ -153,12 +155,15 @@ pub const LoaderCtx = struct { const self = ud orelse unreachable; const xreq: *XevReq = @fieldParentPtr("c", c); + const slot: usize = xreq.req.file.idx; + const request_id = xreq.request_id; const actual = r catch { - self.sendResponseSynced(xreq.request_id, LoaderError.ReadError); + self.fileDecRef(slot); + self.req_mem_pool.destroy(xreq); + self.sendResponseSynced(request_id, LoaderError.ReadError); return .disarm; }; - const slot: usize = @intCast(xreq.req.file.idx); self.fileDecRef(slot); if (actual != xreq.req.result_buffer.len) { @@ -173,25 +178,24 @@ pub const LoaderCtx = struct { fn fileAddRef(self: *Self, slot: usize) void { // Should start on 1 - std.debug.assert(self.file_refcount[slot] > 0); - self.file_refcount[slot] += 1; + std.debug.assert(self.file_refcount[0..][slot] > 0); + self.file_refcount[0..][slot] += 1; } fn fileDecRef(self: *Self, slot: usize) void { - std.debug.assert(self.file_refcount[slot] > 0); - self.file_refcount[slot] -= 1; - if (self.file_refcount[slot] == 0) { - const f = self.file_slots[slot] orelse unreachable; + std.debug.assert(self.file_refcount[0..][slot] > 0); + self.file_refcount[0..][slot] -= 1; + if (self.file_refcount[0..][slot] == 0) { + const f = self.file_slots[0..][slot] orelse unreachable; f.close(); - self.file_slots[slot] = null; + self.file_slots[0..][slot] = null; } } fn handleReq(self: *Self, req_id: u64, req: Request) void { switch (req) { .open_file => |open_req| { - const file_path = open_req.file_path; - wlog.debug("Req {}: open_file: file = {s}", .{ req_id, file_path }); + wlog.debug("Req {}: open_file: file = {s}", .{ req_id, open_req.file_path }); var h = self.findFreeFileSlot() catch |err| { self.sendResponseSynced(req_id, err); @@ -206,16 +210,16 @@ pub const LoaderCtx = struct { const xf = xev.File.init(f) catch unreachable; // Commit state - const slot: usize = @intCast(h.idx); - h.path_checksum = path_checksum(file_path); + const slot: usize = h.idx; + const checksum = path_checksum(open_req.file_path); self.file_slots[slot] = f; self.xfile_slots[slot] = xf; self.file_refcount[slot] = 1; self.file_slots_generation[slot] += 1; // gen 0 is reserved to catch errors - self.file_slots_checksum[slot] = h.path_checksum; + self.file_slots_checksum[slot] = checksum; const gen = self.file_slots_generation[slot]; if (gen > max_generation) @panic("Open file generation overflow"); - h.generation = @intCast(gen); + h = .{ .idx = @intCast(slot), .generation = @intCast(gen), .path_checksum = checksum }; self.sendResponseSynced(req_id, .{ .open_file = h }); }, @@ -228,7 +232,7 @@ pub const LoaderCtx = struct { return; }; - const slot: usize = @intCast(file_handle.idx); + const slot: usize = file_handle.idx; self.fileDecRef(slot); }, @@ -241,7 +245,7 @@ pub const LoaderCtx = struct { return; }; - const slot: usize = @intCast(read_req.file.idx); + const slot: usize = read_req.file.idx; const xf = self.xfile_slots[slot]; // Prepare read request @@ -298,7 +302,7 @@ pub const LoaderCtx = struct { } } - self.worker_thread = null; + self.worker_done.store(true, .release); } // Loader side functions @@ -358,14 +362,14 @@ pub const LoaderCtx = struct { self.drainResponse(); std.Thread.yield() catch {}; } - // Now wait until the worker thread is done - while (self.worker_thread != null) { + // Now wait until the worker thread signals done + while (!self.worker_done.load(.acquire)) { self.drainResponse(); std.Thread.yield() catch {}; } - // When worker thread is null, it means the thread has exited self.drainResponse(); thread.join(); + self.worker_thread = null; } pub fn start(self: *Self) !void { @@ -374,6 +378,7 @@ pub const LoaderCtx = struct { } self.is_running = true; + self.worker_done.store(false, .monotonic); self.worker_thread = try std.Thread.spawn(.{ .allocator = self.alloc, @@ -383,14 +388,27 @@ pub const LoaderCtx = struct { logger.debug("Worker thread started", .{}); } - pub fn init(alloc: std.mem.Allocator) !Self { - return .{ - .alloc = alloc, - .file_slots = [_]?std.fs.File{null} ** max_file_slots, - .file_slots_generation = [_]u32{0} ** max_file_slots, - .loop = try xev.Loop.init(.{}), - .req_mem_pool = try std.heap.MemoryPool(XevReq).initPreheated(alloc, 16), - }; + pub fn initInPlace(self: *Self, alloc: std.mem.Allocator) !void { + self.alloc = alloc; + self.request_ring = ReqRing.init(); + self.result_ring = ResultRing.init(); + self.file_slots = @splat(null); + self.xfile_slots = undefined; + self.file_refcount = @splat(0); + self.file_slots_generation = std.mem.zeroes([max_file_slots]u32); + self.file_slots_checksum = undefined; + self.loop = try xev.Loop.init(.{}); + errdefer self.loop.deinit(); + self.worker_thread = null; + self.worker_done = std.atomic.Value(bool).init(false); + self.req_cnt = 0; + self.is_running = false; + self.is_draining = false; + self.last_slot = 0; + self.debug_max_req_id = std.math.maxInt(u64); + self.tick = 0; + self.debug_max_tick = std.math.maxInt(u64); + self.req_mem_pool = try std.heap.MemoryPool(XevReq).initPreheated(alloc, 16); } pub fn deinit(self: *Self) void { @@ -421,7 +439,9 @@ test "test dataloader" { var debug_alloc = std.heap.DebugAllocator(.{}).init; defer _ = debug_alloc.deinit(); - var ctx = try LoaderCtx.init(debug_alloc.allocator()); + const ctx = try debug_alloc.allocator().create(LoaderCtx); + defer debug_alloc.allocator().destroy(ctx); + try ctx.initInPlace(debug_alloc.allocator()); ctx.debug_max_req_id = 5; ctx.debug_max_tick = 1000; try ctx.start(); @@ -450,7 +470,9 @@ test "test dataloader blocked join" { var debug_alloc = std.heap.DebugAllocator(.{}).init; defer _ = debug_alloc.deinit(); - var ctx = try LoaderCtx.init(debug_alloc.allocator()); + const ctx = try debug_alloc.allocator().create(LoaderCtx); + defer debug_alloc.allocator().destroy(ctx); + try ctx.initInPlace(debug_alloc.allocator()); ctx.debug_max_req_id = 100; ctx.debug_max_tick = 1000; try ctx.start(); @@ -465,3 +487,22 @@ test "test dataloader blocked join" { try std.testing.expect(ctx.is_running == false); try std.testing.expectEqual(null, ctx.result_ring.dequeue()); } + +test "invalid file handle index is rejected" { + var debug_alloc = std.heap.DebugAllocator(.{}).init; + defer _ = debug_alloc.deinit(); + + const ctx = try debug_alloc.allocator().create(LoaderCtx); + defer debug_alloc.allocator().destroy(ctx); + try ctx.initInPlace(debug_alloc.allocator()); + defer ctx.deinit(); + + try std.testing.expectError( + LoaderError.InvalidFileHandle, + ctx.checkFilehandle(.{ + .idx = std.math.maxInt(u20), + .generation = 0, + .path_checksum = 0, + }), + ); +} diff --git a/lua_dataloader.zig b/lua_dataloader.zig index 5019399..5f075af 100644 --- a/lua_dataloader.zig +++ b/lua_dataloader.zig @@ -26,7 +26,7 @@ const c_u8ptr = [*c]const u8; pub const LoadedRow = extern struct { keys: [*c]c_u8ptr = null, data: [*c]c_u8ptr = null, - sizes: [*c]c_uint = null, + sizes: [*c]u64 = null, num_keys: c_uint = 0, }; @@ -80,14 +80,14 @@ pub const LuaDataLoader = struct { }, close_file: struct { // args - file_handle: u32, + file_handle: u64, }, add_entry: struct { // args key: [:0]const u8, offset: u64, size: u32, - file_handle: u32, + file_handle: u64, // state entry: ?*Row.Entry = null, }, @@ -143,7 +143,7 @@ pub const LuaDataLoader = struct { fn gCloseFile(lua: *Lua) !i32 { const loader = try lua.toUserdata(Self, 1); - const handle: u32 = try lua_rt.toUnsigned(lua, 2); + const handle: u64 = try lua_rt.toUnsigned64(lua, 2); loader.u_yielded_from = .{ .close_file = .{ .file_handle = handle, @@ -154,7 +154,7 @@ pub const LuaDataLoader = struct { fn gAddEntry(lua: *Lua) !i32 { const loader = try lua.toUserdata(Self, 1); - const handle: u32 = try lua_rt.toUnsigned(lua, 2); + const handle: u64 = try lua_rt.toUnsigned64(lua, 2); // We yield after this function & the string ref should live long enough const key = try lua.toString(3); const offset: u64 = @intFromFloat(try lua.toNumber(4)); @@ -357,16 +357,16 @@ pub const LuaDataLoader = struct { switch (payload) { .open_file => |f| { - std.debug.assert((self.u_yielded_from orelse @panic("Unresolved open_file req")) == .open_file); - lua_rt.pushUnsigned(self.lua, @intCast(@as(u32, @bitCast(f)))); + if (self.u_yielded_from == null or self.u_yielded_from.? != .open_file) { + return error.UnexpectedOpenFileResponse; + } + lua_rt.pushUnsigned64(self.lua, @bitCast(f)); self.u_resume_nargs = 1; self.u_yielded_from = null; }, .read_block => { - const rid = resp.request_id; - const kv = self.load_rid_to_row.fetchSwapRemove(rid) orelse @panic("read_block rid not found in map"); - const row = kv.value; - row.num_fullfilled += 1; + const kv = self.load_rid_to_row.fetchSwapRemove(resp.request_id) orelse @panic("read_block rid not found in map"); + kv.value.num_fullfilled += 1; }, } } @@ -557,13 +557,41 @@ pub const LuaDataLoader = struct { pub fn init(spec: LuaLoaderSpec, alloc: std.mem.Allocator) !*Self { var self = try alloc.create(Self); - const now = try std.time.Instant.now(); - self.* = .{ .alloc = alloc, .load_rid_to_row = try std.AutoArrayHashMapUnmanaged(u64, *Row).init(alloc, &.{}, &.{}), .last_instant = now, .last_log_instant = now }; errdefer alloc.destroy(self); + const now = try std.time.Instant.now(); + self.alloc = alloc; + self.loader = undefined; + self.lua = undefined; + self.u_loader_fn = .{}; + self.u_ctx = 0; + self.u_ctx_funcs_table = 0; + self.u_resume_nargs = 0; + self.u_yielded_from = null; + self.u_completed = false; + self.queue_size_rows = 4; + self.in_progress_row = null; + self.queue = .{}; + self.queue_len = 0; + self.row_buf_mutex = .{}; + self.free_list = .{}; + self.num_floating_rows = 0; + self.load_rid_to_row = try std.AutoArrayHashMapUnmanaged(u64, *Row).init(alloc, &.{}, &.{}); + self.last_instant = now; + self.last_log_instant = now; + self.mbps_smoothed = 0.0; + self.mbps_period_max = 0.0; + self.samples_count = 0; + errdefer self.load_rid_to_row.deinit(self.alloc); try self.newInprogressRow(); + errdefer { + const row = self.in_progress_row orelse unreachable; + row.deinit(); + self.alloc.destroy(row); + self.in_progress_row = null; + } - self.loader = try LoaderCtx.init(alloc); + try self.loader.initInPlace(alloc); errdefer self.loader.deinit(); try self.loader.start(); diff --git a/lua_rt.zig b/lua_rt.zig index d6afb4a..139e006 100644 --- a/lua_rt.zig +++ b/lua_rt.zig @@ -5,6 +5,10 @@ const msgpack = @import("msgpack.zig"); const logger = std.log.scoped(.lua_rt); +const max_exact_lua_handle = std.math.maxInt(u48); +const can_use_lua_unsigned64 = + (zlua.lang == .luau or zlua.lang == .lua52) and std.math.maxInt(zlua.Unsigned) >= max_exact_lua_handle; + const ScanCtx = struct { dir: std.fs.Dir, iter: std.fs.Dir.Iterator, @@ -88,6 +92,16 @@ pub fn toUnsigned(lua: *Lua, idx: i32) !u32 { } } +pub fn toUnsigned64(lua: *Lua, idx: i32) !u64 { + if (can_use_lua_unsigned64) { + return lua.toUnsigned(idx); + } else { + const f = try lua.toNumber(idx); + std.debug.assert(f >= 0 and f <= max_exact_lua_handle); + return @intFromFloat(f); + } +} + pub fn pushUnsigned(lua: *Lua, v: u32) void { if (zlua.lang == .luau or zlua.lang == .lua52) { lua.pushUnsigned(v); @@ -97,6 +111,16 @@ pub fn pushUnsigned(lua: *Lua, v: u32) void { } } +pub fn pushUnsigned64(lua: *Lua, v: u64) void { + if (can_use_lua_unsigned64) { + lua.pushUnsigned(@intCast(v)); + } else { + std.debug.assert(v <= max_exact_lua_handle); + const f: f64 = @floatFromInt(v); + lua.pushNumber(f); + } +} + pub fn printLuaErr(lua: *Lua, err: zlua.Error) zlua.Error { switch (err) { error.LuaError, error.LuaRuntime, error.LuaSyntax => { diff --git a/python/python.zig b/python/python.zig index fa21108..c1aa109 100644 --- a/python/python.zig +++ b/python/python.zig @@ -1,33 +1,44 @@ //! Python ABI3 bindings for ultar DataLoader. //! -//! ## Object Hierarchy +//! ## Ownership Layout //! //! ``` +//! module object +//! `-- ModuleState +//! |-- data_loader_type: ?*PyTypeObject +//! `-- loaded_row_type: ?*PyTypeObject +//! //! DataLoaderObject -//! ├── ob_base: PyObject (refcount managed by Python) -//! └── loader: *LuaLoaderCCtx (native, destroyed in dealloc) +//! |-- ob_base: PyObject (refcount managed by Python) +//! |-- typ: ?*PyTypeObject (cached heap type for dealloc) +//! `-- loader: ?*LuaLoaderCCtx (native, destroyed once in dealloc) //! //! LoadedRowObject -//! ├── ob_base: PyObject (refcount managed by Python) -//! ├── parent: ?*DataLoaderObject (incref'd reference to parent) -//! └── row: ?*LoadedRow (native ptr, reclaimed to parent's loader) +//! |-- ob_base: PyObject (refcount managed by Python) +//! |-- typ: ?*PyTypeObject (cached heap type for dealloc) +//! |-- parent: ?*DataLoaderObject (incref'd reference to parent) +//! `-- row: ?*LoadedRow (native ptr, reclaimed once before parent release) //! ``` //! //! ## Reference Ownership //! //! - `DataLoaderObject`: Created by `tp_new`, returned to Python with refcount=1. -//! Caller owns it. On dealloc, destroys the native loader. +//! Caller owns it. The instance caches its heap type pointer so custom dealloc can +//! call `tp_free` and discharge the heap-type reference without reading `PyObject` +//! headers directly. //! //! - `LoadedRowObject`: Created by `dataLoaderNext`, returned with refcount=1. //! Holds an incref'd reference to its parent `DataLoaderObject` to keep it alive. -//! On dealloc, reclaims the native row to the parent's loader, then decrefs parent. +//! It also caches its heap type pointer for the same dealloc rule. On dealloc, +//! it reclaims the native row to the parent's loader, then decrefs parent. //! //! - No reference cycles: LoadedRow → DataLoader (one-way ownership). //! //! ## Error Handling Pattern //! -//! Internal functions use Zig error semantics (`PyError!T`) with `errdefer` for cleanup. -//! C ABI wrappers catch errors and set Python exceptions. +//! Internal functions use Zig error semantics (`PyError!T`) with explicit partial-init +//! cleanup. Python-owned fields are zeroed before fallible work so clear/dealloc paths +//! tolerate partially initialized instances. //! //! ## Thread Safety //! @@ -52,16 +63,7 @@ const py = @cImport({ const zeros = std.mem.zeroes; -/// Get the type of a Python object (ABI3-safe replacement for Py_TYPE) -/// In limited API, Py_TYPE may not be exported as a symbol in all Python versions -inline fn pyType(obj: ?*py.PyObject) ?*py.PyTypeObject { - if (obj) |o| { - return o.ob_type; - } - return null; -} - -// ABI3-safe type checking helpers (avoid *_Check macros which use Py_TYPE internally) +// ABI3-safe type checking helpers (avoid *_Check macros that inspect object headers) inline fn isUnicode(obj: ?*py.PyObject) bool { if (obj) |o| { return py.PyObject_IsInstance(o, @ptrCast(@alignCast(&py.PyUnicode_Type))) == 1; @@ -83,23 +85,47 @@ inline fn isDict(obj: ?*py.PyObject) bool { return false; } +const ModuleState = struct { + data_loader_type: ?*py.PyTypeObject = null, + loaded_row_type: ?*py.PyTypeObject = null, +}; + +inline fn moduleState(module: *py.PyObject) *ModuleState { + return @ptrCast(@alignCast(py.PyModule_GetState(module).?)); +} + +inline fn moduleStateFromType(typ: *py.PyTypeObject) *ModuleState { + return @ptrCast(@alignCast(py.PyType_GetModuleState(typ).?)); +} + +inline fn loadedRowType(parent: *DataLoaderObject) PyError!*py.PyTypeObject { + const parent_type = parent.typ orelse return error.RuntimeError; + return moduleStateFromType(parent_type).loaded_row_type orelse return error.RuntimeError; +} + +fn freeHeapTypeInstance(typ: ?*py.PyTypeObject, self_obj: ?*py.PyObject) void { + const heap_type = typ orelse return; + const free_fn = py.PyType_GetSlot(heap_type, py.Py_tp_free) orelse unreachable; + const free: *const fn (?*anyopaque) callconv(.c) void = @ptrCast(@alignCast(free_fn)); + free(self_obj); + py.Py_DecRef(@ptrCast(@alignCast(heap_type))); +} + // Our DataLoader object const DataLoaderObject = extern struct { ob_base: py.PyObject, + typ: ?*py.PyTypeObject, loader: ?*LuaLoaderCCtx, }; // Our LoadedRow object (represents a single row from the dataloader) const LoadedRowObject = extern struct { ob_base: py.PyObject, + typ: ?*py.PyTypeObject, parent: ?*DataLoaderObject, // Keep reference to parent row: ?*LoadedRow, }; -// Type objects - stored as PyObject pointers (opaque with limited API) -var DataLoaderType: ?*py.PyTypeObject = null; -var LoadedRowType: ?*py.PyTypeObject = null; - // Slot definitions for DataLoader type const DataLoader_slots = [_]py.PyType_Slot{ .{ .slot = py.Py_tp_new, .pfunc = @ptrCast(@constCast(&dataLoaderNew)) }, @@ -281,14 +307,11 @@ fn dataLoaderNewImpl( const alloc_fn = py.PyType_GetSlot(typ, py.Py_tp_alloc) orelse return error.RuntimeError; const alloc: *const fn (?*py.PyTypeObject, py.Py_ssize_t) callconv(.c) ?*py.PyObject = @ptrCast(@alignCast(alloc_fn)); const self_obj = alloc(typ, 0) orelse return error.PythonException; - errdefer { - if (py.PyType_GetSlot(typ, py.Py_tp_free)) |f| { - const free: *const fn (?*anyopaque) callconv(.c) void = @ptrCast(@alignCast(f)); - free(self_obj); - } - } const self: *DataLoaderObject = @ptrCast(@alignCast(self_obj)); + self.typ = typ; + self.loader = null; + errdefer freeHeapTypeInstance(self.typ, self_obj); // Release GIL during heavy initialization const gil_state = py.PyEval_SaveThread(); @@ -353,21 +376,20 @@ fn dataLoaderNew(typ: ?*py.PyTypeObject, args: ?*py.PyObject, kwargs: ?*py.PyObj return @ptrCast(self); } -fn dataLoaderDealloc(self_obj: ?*py.PyObject) callconv(.c) void { - const self: *DataLoaderObject = @ptrCast(@alignCast(self_obj)); - +fn dataLoaderClear(self: *DataLoaderObject) void { if (self.loader) |loader| { - lua_dataloader.ultarDestroyLuaLoader(loader); self.loader = null; + lua_dataloader.ultarDestroyLuaLoader(loader); } +} - // Get the type and call tp_free - const typ = pyType(self_obj); - const free_fn = py.PyType_GetSlot(typ, py.Py_tp_free); - if (free_fn) |f| { - const free: *const fn (?*anyopaque) callconv(.c) void = @ptrCast(@alignCast(f)); - free(self_obj); - } +fn dataLoaderDealloc(self_obj: ?*py.PyObject) callconv(.c) void { + const self: *DataLoaderObject = @ptrCast(@alignCast(self_obj)); + + dataLoaderClear(self); + const typ = self.typ; + self.typ = null; + freeHeapTypeInstance(typ, self_obj); } fn dataLoaderRepr(_: ?*py.PyObject) callconv(.c) ?*py.PyObject { @@ -415,13 +437,16 @@ fn dataLoaderNext(self_obj: ?*py.PyObject) callconv(.c) ?*py.PyObject { /// On success: The returned Python object owns `row` and will reclaim it on dealloc. /// On failure: `row` is returned, caller must handle reclaim. fn wrapOwnedRow(parent: *DataLoaderObject, row: *LoadedRow) PyError!*py.PyObject { - const typ = LoadedRowType orelse return error.RuntimeError; + const typ = try loadedRowType(parent); const alloc_fn = py.PyType_GetSlot(typ, py.Py_tp_alloc) orelse return error.RuntimeError; const alloc: *const fn (?*py.PyTypeObject, py.Py_ssize_t) callconv(.c) ?*py.PyObject = @ptrCast(@alignCast(alloc_fn)); const self_obj = alloc(typ, 0) orelse return error.PythonException; const row_obj: *LoadedRowObject = @ptrCast(@alignCast(self_obj)); + row_obj.typ = typ; + row_obj.parent = null; + row_obj.row = null; row_obj.parent = parent; row_obj.row = row; @@ -430,37 +455,28 @@ fn wrapOwnedRow(parent: *DataLoaderObject, row: *LoadedRow) PyError!*py.PyObject return self_obj; } -fn loadedRowDealloc(self_obj: ?*py.PyObject) callconv(.c) void { - const self: *LoadedRowObject = @ptrCast(@alignCast(self_obj)); - - // Reclaim the row before releasing parent - // Null out row immediately to prevent double-free on any error path +fn loadedRowClear(self: *LoadedRowObject) void { const row_to_reclaim = self.row; self.row = null; if (self.parent) |parent| { + self.parent = null; if (row_to_reclaim) |row| { if (parent.loader) |loader| { lua_dataloader.ultarReclaimRow(loader, row); } - // Note: if parent.loader is null, the loader was already destroyed. - // This shouldn't happen with correct refcounting (we hold a ref to parent), - // but if it does, the row memory is already freed by ultarDestroyLuaLoader. } - self.parent = null; py.Py_DecRef(@ptrCast(parent)); } - // Note: if self.parent is null but row_to_reclaim was set, we have a bug - // in createLoadedRowObject. The row is leaked but we can't reclaim it - // without knowing which loader it belongs to. - - // Get the type and call tp_free - const typ = pyType(self_obj); - const free_fn = py.PyType_GetSlot(typ, py.Py_tp_free); - if (free_fn) |f| { - const free: *const fn (?*anyopaque) callconv(.c) void = @ptrCast(@alignCast(f)); - free(self_obj); - } +} + +fn loadedRowDealloc(self_obj: ?*py.PyObject) callconv(.c) void { + const self: *LoadedRowObject = @ptrCast(@alignCast(self_obj)); + + loadedRowClear(self); + const typ = self.typ; + self.typ = null; + freeHeapTypeInstance(typ, self_obj); } fn loadedRowRepr(self_obj: ?*py.PyObject) callconv(.c) ?*py.PyObject { @@ -645,6 +661,63 @@ const module_methods = [_]py.PyMethodDef{ std.mem.zeroes(py.PyMethodDef), }; +fn moduleTraverse(module_obj: ?*py.PyObject, visit: py.visitproc, arg: ?*anyopaque) callconv(.c) c_int { + const state = moduleState(module_obj.?); + + if (state.data_loader_type) |typ| { + if (visit.?(@ptrCast(@alignCast(typ)), arg) != 0) return -1; + } + if (state.loaded_row_type) |typ| { + if (visit.?(@ptrCast(@alignCast(typ)), arg) != 0) return -1; + } + return 0; +} + +fn moduleClear(module_obj: ?*py.PyObject) callconv(.c) c_int { + const state = moduleState(module_obj.?); + + if (state.data_loader_type) |typ| { + state.data_loader_type = null; + py.Py_DecRef(@ptrCast(@alignCast(typ))); + } + if (state.loaded_row_type) |typ| { + state.loaded_row_type = null; + py.Py_DecRef(@ptrCast(@alignCast(typ))); + } + return 0; +} + +fn moduleExec(module_obj: ?*py.PyObject) callconv(.c) c_int { + const module = module_obj orelse return -1; + const state = moduleState(module); + state.* = zeros(ModuleState); + + state.data_loader_type = @ptrCast(py.PyType_FromModuleAndSpec(module, &DataLoader_spec, null)); + if (state.data_loader_type == null) return -1; + + state.loaded_row_type = @ptrCast(py.PyType_FromModuleAndSpec(module, &LoadedRow_spec, null)); + if (state.loaded_row_type == null) { + _ = moduleClear(module_obj); + return -1; + } + + if (py.PyModule_AddObjectRef(module, "DataLoader", @ptrCast(@alignCast(state.data_loader_type))) < 0) { + _ = moduleClear(module_obj); + return -1; + } + if (py.PyModule_AddObjectRef(module, "LoadedRow", @ptrCast(@alignCast(state.loaded_row_type))) < 0) { + _ = moduleClear(module_obj); + return -1; + } + + return 0; +} + +const module_slots = [_]py.PyModuleDef_Slot{ + .{ .slot = py.Py_mod_exec, .value = @ptrCast(@constCast(&moduleExec)) }, + zeros(py.PyModuleDef_Slot), +}; + var module_def: py.PyModuleDef = undefined; var module_def_initialized = false; @@ -656,46 +729,16 @@ fn getModuleDef() *py.PyModuleDef { module_def.m_name = "ultar_dataloader._native"; module_def.m_doc = "Fast async dataloader with Lua scripting (Zig implementation)"; - module_def.m_size = -1; + module_def.m_size = @sizeOf(ModuleState); module_def.m_methods = @ptrCast(@constCast(&module_methods)); + module_def.m_slots = @ptrCast(@constCast(&module_slots)); + module_def.m_traverse = &moduleTraverse; + module_def.m_clear = &moduleClear; module_def_initialized = true; } return &module_def; } export fn PyInit__native() ?*py.PyObject { - // Create DataLoader type using PyType_FromSpec - DataLoaderType = @ptrCast(py.PyType_FromSpec(&DataLoader_spec)); - if (DataLoaderType == null) { - return null; - } - - // Create LoadedRow type using PyType_FromSpec - LoadedRowType = @ptrCast(py.PyType_FromSpec(&LoadedRow_spec)); - if (LoadedRowType == null) { - py.Py_DecRef(@ptrCast(@alignCast(DataLoaderType))); - return null; - } - - const m = py.PyModule_Create(getModuleDef()) orelse { - py.Py_DecRef(@ptrCast(@alignCast(DataLoaderType))); - py.Py_DecRef(@ptrCast(@alignCast(LoadedRowType))); - return null; - }; - - if (py.PyModule_AddObjectRef(m, "DataLoader", @ptrCast(@alignCast(DataLoaderType))) < 0) { - py.Py_DecRef(m); - py.Py_DecRef(@ptrCast(@alignCast(DataLoaderType))); - py.Py_DecRef(@ptrCast(@alignCast(LoadedRowType))); - return null; - } - - if (py.PyModule_AddObjectRef(m, "LoadedRow", @ptrCast(@alignCast(LoadedRowType))) < 0) { - py.Py_DecRef(m); - py.Py_DecRef(@ptrCast(@alignCast(DataLoaderType))); - py.Py_DecRef(@ptrCast(@alignCast(LoadedRowType))); - return null; - } - - return m; + return py.PyModuleDef_Init(getModuleDef()); } diff --git a/python/tests/test_limited_abi.py b/python/tests/test_limited_abi.py new file mode 100644 index 0000000..836e019 --- /dev/null +++ b/python/tests/test_limited_abi.py @@ -0,0 +1,213 @@ +import gc +import io +import os +import subprocess +import sys +import tarfile +import textwrap +from dataclasses import dataclass +from pathlib import Path + +import pytest + + +REPO_ROOT = Path(__file__).resolve().parents[2] +PACKAGE_ROOT = REPO_ROOT / "python" / "src" +INDEXER = REPO_ROOT / "zig-out" / "bin" / "indexer" +LOADER_SCRIPT = Path(__file__).with_name("loader_script.lua") + + +@dataclass(frozen=True) +class GeneratedFixture: + tar_path: Path + index_path: Path + + +def _make_tar_fixture(root: Path) -> GeneratedFixture: + tar_path = root / "generated.tar" + rows = { + "row0": { + ".txt": b"first row text", + ".json": b'{"row": 0}', + ".bin": bytes([0, 1, 2, 3]), + }, + "row1": { + ".txt": b"second row text", + ".json": b'{"row": 1}', + ".bin": bytes([4, 5, 6, 7]), + }, + "row2": { + ".txt": b"third row text", + ".json": b'{"row": 2}', + ".bin": bytes([8, 9, 10, 11]), + }, + } + + with tarfile.open(tar_path, "w") as archive: + for row_name, entries in rows.items(): + for suffix, payload in entries.items(): + member = tarfile.TarInfo(f"{row_name}{suffix}") + member.size = len(payload) + archive.addfile(member, io.BytesIO(payload)) + + subprocess.run([str(INDEXER), "-f", str(tar_path)], cwd=root, check=True) + return GeneratedFixture(tar_path=tar_path, index_path=Path(f"{tar_path}.utix")) + + +def _pythonpath_env() -> dict[str, str]: + env = os.environ.copy() + path_parts = [str(PACKAGE_ROOT)] + if current := env.get("PYTHONPATH"): + path_parts.append(current) + env["PYTHONPATH"] = os.pathsep.join(path_parts) + return env + + +def _run_subprocess( + case: str, generated_fixture: GeneratedFixture +) -> subprocess.CompletedProcess[str]: + return subprocess.run( + [ + sys.executable, + "-c", + textwrap.dedent(case), + str(LOADER_SCRIPT), + str(generated_fixture.tar_path), + str(generated_fixture.index_path), + ], + cwd=REPO_ROOT, + env=_pythonpath_env(), + capture_output=True, + text=True, + ) + + +def _assert_clean_exit(result: subprocess.CompletedProcess[str]) -> None: + assert result.returncode == 0, ( + f"subprocess exited with {result.returncode}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + + +@pytest.fixture(scope="module") +def generated_fixture(tmp_path_factory: pytest.TempPathFactory) -> GeneratedFixture: + root = tmp_path_factory.mktemp("limited-abi") + return _make_tar_fixture(root) + + +def test_subprocess_repeated_construction_and_cleanup( + generated_fixture: GeneratedFixture, +) -> None: + result = _run_subprocess( + """ + import gc + import sys + from pathlib import Path + + from ultar_dataloader import DataLoader + + script = Path(sys.argv[1]).read_text() + config = { + "tar_path": sys.argv[2], + "idx_path": sys.argv[3], + "max_rows": "2", + } + + for _ in range(40): + loader = DataLoader(src=script, config=config) + rows = list(loader) + assert len(rows) == 2 + assert rows[0].keys() == [".txt", ".json", ".bin"] + assert rows[0][".txt"] == b"first row text" + assert rows[0].to_dict()[".json"] == b'{"row": 0}' + del rows + del loader + for _ in range(3): + gc.collect() + """, + generated_fixture, + ) + + _assert_clean_exit(result) + + +def test_subprocess_constructor_failure_cleanup( + generated_fixture: GeneratedFixture, +) -> None: + result = _run_subprocess( + """ + import gc + from ultar_dataloader import DataLoader + + for _ in range(40): + try: + DataLoader(src="this is not lua") + except RuntimeError: + pass + else: + raise AssertionError("expected RuntimeError for invalid Lua source") + for _ in range(3): + gc.collect() + """, + generated_fixture, + ) + + _assert_clean_exit(result) + + +def test_subprocess_row_error_paths_and_parent_release( + generated_fixture: GeneratedFixture, +) -> None: + result = _run_subprocess( + """ + import gc + import sys + from pathlib import Path + + from ultar_dataloader import DataLoader + + script = Path(sys.argv[1]).read_text() + config = { + "tar_path": sys.argv[2], + "idx_path": sys.argv[3], + "max_rows": "1", + } + + for _ in range(40): + loader = DataLoader(src=script, config=config) + row = next(iter(loader)) + del loader + gc.collect() + + assert row[0] == b"first row text" + assert row[-1] == bytes([0, 1, 2, 3]) + + try: + row[99] + except IndexError: + pass + else: + raise AssertionError("expected IndexError for out-of-range row access") + + try: + row[".missing"] + except KeyError: + pass + else: + raise AssertionError("expected KeyError for missing row key") + + del row + for _ in range(3): + gc.collect() + """, + generated_fixture, + ) + + _assert_clean_exit(result) + + +def test_generated_fixture_reuse_is_stable(generated_fixture: GeneratedFixture) -> None: + assert generated_fixture.tar_path.exists() + assert generated_fixture.index_path.exists() + gc.collect()