Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 85 additions & 44 deletions dataloader.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ const logger = std.log.scoped(.dataloader);
const wlog = std.log.scoped(.dataloader_io_thread);

pub const FileHandle = packed struct {
_: u14 = 0,
idx: u20,
generation: u20,
path_checksum: u8,
_: u16 = 0,
};

const max_file_slots = std.math.maxInt(@FieldType(FileHandle, "idx"));
const max_generation = std.math.maxInt(@FieldType(FileHandle, "generation"));
const max_file_slots = std.math.maxInt(u20);
const max_generation = std.math.maxInt(u20);

pub const ReadBlockReq = struct {
base: u64,
Expand Down Expand Up @@ -88,6 +88,7 @@ pub const LoaderCtx = struct {

loop: xev.Loop,
worker_thread: ?std.Thread = null,
worker_done: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),

req_cnt: u64 = 0,
is_running: bool = false,
Expand All @@ -104,25 +105,26 @@ pub const LoaderCtx = struct {
// This isn't called too often, just linear scan
for (0..max_file_slots) |offset| {
const i = (self.last_slot + offset) % max_file_slots;
const f = self.file_slots[i];
const f = self.file_slots[0..][i];
if (f == null) {
self.last_slot = (i + 1) % max_file_slots;
return .{
.idx = @intCast(i),
.generation = 0,
.path_checksum = 0,
};
return .{ .idx = @intCast(i), .generation = 0, .path_checksum = 0 };
}
}

return LoaderError.TooManyOpenFiles;
}

fn checkFilehandle(self: *Self, file: FileHandle) !void {
const slot: usize = @intCast(file.idx);
if (self.file_slots[slot] == null) return LoaderError.InvalidFileHandle;
if (self.file_slots_generation[slot] != file.generation or self.file_slots_checksum[slot] != file.path_checksum) {
logger.warn("File handle {} is corrupted, current generation: {}, checksum: {}", .{ file, self.file_slots_generation[slot], self.file_slots_checksum[slot] });
const slot: usize = file.idx;
const file_slots = self.file_slots[0..];
const generations = self.file_slots_generation[0..];
const checksums = self.file_slots_checksum[0..];

if (slot >= file_slots.len) return LoaderError.InvalidFileHandle;
if (file_slots[slot] == null) return LoaderError.InvalidFileHandle;
if (generations[slot] != file.generation or checksums[slot] != file.path_checksum) {
logger.warn("File handle {} is corrupted, current generation: {}, checksum: {}", .{ file, generations[slot], checksums[slot] });
return LoaderError.InvalidFileHandle;
}
}
Expand Down Expand Up @@ -153,12 +155,15 @@ pub const LoaderCtx = struct {
const self = ud orelse unreachable;

const xreq: *XevReq = @fieldParentPtr("c", c);
const slot: usize = xreq.req.file.idx;
const request_id = xreq.request_id;

const actual = r catch {
self.sendResponseSynced(xreq.request_id, LoaderError.ReadError);
self.fileDecRef(slot);
self.req_mem_pool.destroy(xreq);
self.sendResponseSynced(request_id, LoaderError.ReadError);
return .disarm;
};
const slot: usize = @intCast(xreq.req.file.idx);
self.fileDecRef(slot);

if (actual != xreq.req.result_buffer.len) {
Expand All @@ -173,25 +178,24 @@ pub const LoaderCtx = struct {

fn fileAddRef(self: *Self, slot: usize) void {
// Should start on 1
std.debug.assert(self.file_refcount[slot] > 0);
self.file_refcount[slot] += 1;
std.debug.assert(self.file_refcount[0..][slot] > 0);
self.file_refcount[0..][slot] += 1;
}

fn fileDecRef(self: *Self, slot: usize) void {
std.debug.assert(self.file_refcount[slot] > 0);
self.file_refcount[slot] -= 1;
if (self.file_refcount[slot] == 0) {
const f = self.file_slots[slot] orelse unreachable;
std.debug.assert(self.file_refcount[0..][slot] > 0);
self.file_refcount[0..][slot] -= 1;
if (self.file_refcount[0..][slot] == 0) {
const f = self.file_slots[0..][slot] orelse unreachable;
f.close();
self.file_slots[slot] = null;
self.file_slots[0..][slot] = null;
}
}

fn handleReq(self: *Self, req_id: u64, req: Request) void {
switch (req) {
.open_file => |open_req| {
const file_path = open_req.file_path;
wlog.debug("Req {}: open_file: file = {s}", .{ req_id, file_path });
wlog.debug("Req {}: open_file: file = {s}", .{ req_id, open_req.file_path });

var h = self.findFreeFileSlot() catch |err| {
self.sendResponseSynced(req_id, err);
Expand All @@ -206,16 +210,16 @@ pub const LoaderCtx = struct {
const xf = xev.File.init(f) catch unreachable;

// Commit state
const slot: usize = @intCast(h.idx);
h.path_checksum = path_checksum(file_path);
const slot: usize = h.idx;
const checksum = path_checksum(open_req.file_path);
self.file_slots[slot] = f;
self.xfile_slots[slot] = xf;
self.file_refcount[slot] = 1;
self.file_slots_generation[slot] += 1; // gen 0 is reserved to catch errors
self.file_slots_checksum[slot] = h.path_checksum;
self.file_slots_checksum[slot] = checksum;
const gen = self.file_slots_generation[slot];
if (gen > max_generation) @panic("Open file generation overflow");
h.generation = @intCast(gen);
h = .{ .idx = @intCast(slot), .generation = @intCast(gen), .path_checksum = checksum };

self.sendResponseSynced(req_id, .{ .open_file = h });
},
Expand All @@ -228,7 +232,7 @@ pub const LoaderCtx = struct {
return;
};

const slot: usize = @intCast(file_handle.idx);
const slot: usize = file_handle.idx;

self.fileDecRef(slot);
},
Expand All @@ -241,7 +245,7 @@ pub const LoaderCtx = struct {
return;
};

const slot: usize = @intCast(read_req.file.idx);
const slot: usize = read_req.file.idx;
const xf = self.xfile_slots[slot];

// Prepare read request
Expand Down Expand Up @@ -298,7 +302,7 @@ pub const LoaderCtx = struct {
}
}

self.worker_thread = null;
self.worker_done.store(true, .release);
}

// Loader side functions
Expand Down Expand Up @@ -358,14 +362,14 @@ pub const LoaderCtx = struct {
self.drainResponse();
std.Thread.yield() catch {};
}
// Now wait until the worker thread is done
while (self.worker_thread != null) {
// Now wait until the worker thread signals done
while (!self.worker_done.load(.acquire)) {
self.drainResponse();
std.Thread.yield() catch {};
}
// When worker thread is null, it means the thread has exited
self.drainResponse();
thread.join();
self.worker_thread = null;
}

pub fn start(self: *Self) !void {
Expand All @@ -374,6 +378,7 @@ pub const LoaderCtx = struct {
}

self.is_running = true;
self.worker_done.store(false, .monotonic);

self.worker_thread = try std.Thread.spawn(.{
.allocator = self.alloc,
Expand All @@ -383,14 +388,27 @@ pub const LoaderCtx = struct {
logger.debug("Worker thread started", .{});
}

pub fn init(alloc: std.mem.Allocator) !Self {
return .{
.alloc = alloc,
.file_slots = [_]?std.fs.File{null} ** max_file_slots,
.file_slots_generation = [_]u32{0} ** max_file_slots,
.loop = try xev.Loop.init(.{}),
.req_mem_pool = try std.heap.MemoryPool(XevReq).initPreheated(alloc, 16),
};
pub fn initInPlace(self: *Self, alloc: std.mem.Allocator) !void {
self.alloc = alloc;
self.request_ring = ReqRing.init();
self.result_ring = ResultRing.init();
self.file_slots = @splat(null);
self.xfile_slots = undefined;
self.file_refcount = @splat(0);
self.file_slots_generation = std.mem.zeroes([max_file_slots]u32);
self.file_slots_checksum = undefined;
self.loop = try xev.Loop.init(.{});
errdefer self.loop.deinit();
self.worker_thread = null;
self.worker_done = std.atomic.Value(bool).init(false);
self.req_cnt = 0;
self.is_running = false;
self.is_draining = false;
self.last_slot = 0;
self.debug_max_req_id = std.math.maxInt(u64);
self.tick = 0;
self.debug_max_tick = std.math.maxInt(u64);
self.req_mem_pool = try std.heap.MemoryPool(XevReq).initPreheated(alloc, 16);
}

pub fn deinit(self: *Self) void {
Expand Down Expand Up @@ -421,7 +439,9 @@ test "test dataloader" {
var debug_alloc = std.heap.DebugAllocator(.{}).init;
defer _ = debug_alloc.deinit();

var ctx = try LoaderCtx.init(debug_alloc.allocator());
const ctx = try debug_alloc.allocator().create(LoaderCtx);
defer debug_alloc.allocator().destroy(ctx);
try ctx.initInPlace(debug_alloc.allocator());
ctx.debug_max_req_id = 5;
ctx.debug_max_tick = 1000;
try ctx.start();
Expand Down Expand Up @@ -450,7 +470,9 @@ test "test dataloader blocked join" {
var debug_alloc = std.heap.DebugAllocator(.{}).init;
defer _ = debug_alloc.deinit();

var ctx = try LoaderCtx.init(debug_alloc.allocator());
const ctx = try debug_alloc.allocator().create(LoaderCtx);
defer debug_alloc.allocator().destroy(ctx);
try ctx.initInPlace(debug_alloc.allocator());
ctx.debug_max_req_id = 100;
ctx.debug_max_tick = 1000;
try ctx.start();
Expand All @@ -465,3 +487,22 @@ test "test dataloader blocked join" {
try std.testing.expect(ctx.is_running == false);
try std.testing.expectEqual(null, ctx.result_ring.dequeue());
}

test "invalid file handle index is rejected" {
var debug_alloc = std.heap.DebugAllocator(.{}).init;
defer _ = debug_alloc.deinit();

const ctx = try debug_alloc.allocator().create(LoaderCtx);
defer debug_alloc.allocator().destroy(ctx);
try ctx.initInPlace(debug_alloc.allocator());
defer ctx.deinit();

try std.testing.expectError(
LoaderError.InvalidFileHandle,
ctx.checkFilehandle(.{
.idx = std.math.maxInt(u20),
.generation = 0,
.path_checksum = 0,
}),
);
}
56 changes: 42 additions & 14 deletions lua_dataloader.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const c_u8ptr = [*c]const u8;
pub const LoadedRow = extern struct {
keys: [*c]c_u8ptr = null,
data: [*c]c_u8ptr = null,
sizes: [*c]c_uint = null,
sizes: [*c]u64 = null,
num_keys: c_uint = 0,
};

Expand Down Expand Up @@ -80,14 +80,14 @@ pub const LuaDataLoader = struct {
},
close_file: struct {
// args
file_handle: u32,
file_handle: u64,
},
add_entry: struct {
// args
key: [:0]const u8,
offset: u64,
size: u32,
file_handle: u32,
file_handle: u64,
// state
entry: ?*Row.Entry = null,
},
Expand Down Expand Up @@ -143,7 +143,7 @@ pub const LuaDataLoader = struct {

fn gCloseFile(lua: *Lua) !i32 {
const loader = try lua.toUserdata(Self, 1);
const handle: u32 = try lua_rt.toUnsigned(lua, 2);
const handle: u64 = try lua_rt.toUnsigned64(lua, 2);
loader.u_yielded_from = .{
.close_file = .{
.file_handle = handle,
Expand All @@ -154,7 +154,7 @@ pub const LuaDataLoader = struct {

fn gAddEntry(lua: *Lua) !i32 {
const loader = try lua.toUserdata(Self, 1);
const handle: u32 = try lua_rt.toUnsigned(lua, 2);
const handle: u64 = try lua_rt.toUnsigned64(lua, 2);
// We yield after this function & the string ref should live long enough
const key = try lua.toString(3);
const offset: u64 = @intFromFloat(try lua.toNumber(4));
Expand Down Expand Up @@ -357,16 +357,16 @@ pub const LuaDataLoader = struct {

switch (payload) {
.open_file => |f| {
std.debug.assert((self.u_yielded_from orelse @panic("Unresolved open_file req")) == .open_file);
lua_rt.pushUnsigned(self.lua, @intCast(@as(u32, @bitCast(f))));
if (self.u_yielded_from == null or self.u_yielded_from.? != .open_file) {
return error.UnexpectedOpenFileResponse;
}
lua_rt.pushUnsigned64(self.lua, @bitCast(f));
self.u_resume_nargs = 1;
self.u_yielded_from = null;
},
.read_block => {
const rid = resp.request_id;
const kv = self.load_rid_to_row.fetchSwapRemove(rid) orelse @panic("read_block rid not found in map");
const row = kv.value;
row.num_fullfilled += 1;
const kv = self.load_rid_to_row.fetchSwapRemove(resp.request_id) orelse @panic("read_block rid not found in map");
kv.value.num_fullfilled += 1;
},
}
}
Expand Down Expand Up @@ -557,13 +557,41 @@ pub const LuaDataLoader = struct {

pub fn init(spec: LuaLoaderSpec, alloc: std.mem.Allocator) !*Self {
var self = try alloc.create(Self);
const now = try std.time.Instant.now();
self.* = .{ .alloc = alloc, .load_rid_to_row = try std.AutoArrayHashMapUnmanaged(u64, *Row).init(alloc, &.{}, &.{}), .last_instant = now, .last_log_instant = now };
errdefer alloc.destroy(self);
const now = try std.time.Instant.now();
self.alloc = alloc;
self.loader = undefined;
self.lua = undefined;
self.u_loader_fn = .{};
self.u_ctx = 0;
self.u_ctx_funcs_table = 0;
self.u_resume_nargs = 0;
self.u_yielded_from = null;
self.u_completed = false;
self.queue_size_rows = 4;
self.in_progress_row = null;
self.queue = .{};
self.queue_len = 0;
self.row_buf_mutex = .{};
self.free_list = .{};
self.num_floating_rows = 0;
self.load_rid_to_row = try std.AutoArrayHashMapUnmanaged(u64, *Row).init(alloc, &.{}, &.{});
self.last_instant = now;
self.last_log_instant = now;
self.mbps_smoothed = 0.0;
self.mbps_period_max = 0.0;
self.samples_count = 0;
errdefer self.load_rid_to_row.deinit(self.alloc);

try self.newInprogressRow();
errdefer {
const row = self.in_progress_row orelse unreachable;
row.deinit();
self.alloc.destroy(row);
self.in_progress_row = null;
}

self.loader = try LoaderCtx.init(alloc);
try self.loader.initInPlace(alloc);
errdefer self.loader.deinit();
try self.loader.start();

Expand Down
Loading
Loading