From f515056dc16a68d9f25beae52092119fa39c282b Mon Sep 17 00:00:00 2001 From: Raiden1411 <67233402+Raiden1411@users.noreply.github.com> Date: Fri, 22 Aug 2025 22:46:16 +0100 Subject: [PATCH 1/3] std: update websocket server implementation This commit addresses the #24937 issue and updates the websocket implementation to support both fragments and longer messages. --- lib/std/Build/Fuzz.zig | 6 +- lib/std/Build/WebServer.zig | 14 +- lib/std/http/Server.zig | 303 +++++++++++++++++++++++++++++------- 3 files changed, 261 insertions(+), 62 deletions(-) diff --git a/lib/std/Build/Fuzz.zig b/lib/std/Build/Fuzz.zig index bc10f7907a2e..162c86b075c2 100644 --- a/lib/std/Build/Fuzz.zig +++ b/lib/std/Build/Fuzz.zig @@ -270,7 +270,7 @@ pub fn sendUpdate( @ptrCast(coverage_map.source_locations), coverage_map.coverage.string_bytes.items, }; - try socket.writeMessageVec(&iovecs, .binary); + try socket.writeFrameVec(&iovecs, .binary); } const header: abi.CoverageUpdateHeader = .{ @@ -281,7 +281,7 @@ pub fn sendUpdate( @ptrCast(&header), @ptrCast(seen_pcs), }; - try socket.writeMessageVec(&iovecs, .binary); + try socket.writeFrameVec(&iovecs, .binary); prev.unique_runs = unique_runs; } @@ -292,7 +292,7 @@ pub fn sendUpdate( @ptrCast(&header), @ptrCast(coverage_map.entry_points.items), }; - try socket.writeMessageVec(&iovecs, .binary); + try socket.writeFrameVec(&iovecs, .binary); prev.entry_points = coverage_map.entry_points.items.len; } diff --git a/lib/std/Build/WebServer.zig b/lib/std/Build/WebServer.zig index 451f4b9d34d4..1c84e657079d 100644 --- a/lib/std/Build/WebServer.zig +++ b/lib/std/Build/WebServer.zig @@ -253,9 +253,10 @@ fn accept(ws: *WebServer, connection: std.net.Server.Connection) void { switch (request.upgradeRequested()) { .websocket => |opt_key| { const key = opt_key orelse return log.err("missing websocket key", .{}); - var web_socket = request.respondWebSocket(.{ .key = key }) catch { + var web_socket = request.respondWebSocket(.{ .key = key, .allocator = 
ws.gpa }) catch { return log.err("failed to respond web socket: {t}", .{connection_writer.err.?}); }; + defer web_socket.close(0); ws.serveWebSocket(&web_socket) catch |err| { log.err("failed to serve websocket: {t}", .{err}); return; @@ -298,7 +299,7 @@ fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { .steps_len = @intCast(ws.all_steps.len), }; var bufs: [3][]const u8 = .{ @ptrCast(&hello_header), ws.step_names_trailing, prev_step_status_bits }; - try sock.writeMessageVec(&bufs, .binary); + try sock.writeFrameVec(&bufs, .binary); } var prev_fuzz: Fuzz.Previous = .init; @@ -323,7 +324,8 @@ fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { // Temporarily unlock, then re-lock after the message is sent. ws.time_report_mutex.unlock(); defer ws.time_report_mutex.lock(); - try sock.writeMessage(owned_msg, .binary); + + try sock.writeFrame(owned_msg, .binary); } } @@ -332,7 +334,7 @@ fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { if (build_status != prev_build_status) { prev_build_status = build_status; const msg: abi.StatusUpdate = .{ .new = build_status }; - try sock.writeMessage(@ptrCast(&msg), .binary); + try sock.writeFrame(@ptrCast(&msg), .binary); } } @@ -353,7 +355,7 @@ fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { }; for (cur, prev, byte_idx * 4..) 
|cur_status, prev_status, step_idx| { const msg: abi.StepUpdate = .{ .step_idx = @intCast(step_idx), .bits = .{ .status = cur_status } }; - if (cur_status != prev_status) try sock.writeMessage(@ptrCast(&msg), .binary); + if (cur_status != prev_status) try sock.writeFrame(@ptrCast(&msg), .binary); } prev_byte.* = cur_byte; } @@ -364,7 +366,7 @@ fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { } fn recvWebSocketMessages(ws: *WebServer, sock: *http.Server.WebSocket) void { while (true) { - const msg = sock.readSmallMessage() catch return; + const msg = sock.readMessage() catch return; if (msg.opcode != .binary) continue; if (msg.data.len == 0) continue; const tag: abi.ToServerTag = @enumFromInt(msg.data[0]); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index b64253f975bf..3d0e2f13d875 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -520,6 +520,7 @@ pub const Request = struct { } pub const WebSocketOptions = struct { + allocator: std.mem.Allocator, /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value). key: []const u8, reason: ?[]const u8 = null, @@ -559,9 +560,14 @@ pub const Request = struct { try out.writeAll("\r\n"); return .{ + .fragment = .{ + .alloc_writer = .init(options.allocator), + .message_type = null, + }, .input = request.server.reader.in, - .output = request.server.out, .key = options.key, + .output = request.server.out, + .storage = .init(options.allocator), }; } @@ -652,8 +658,13 @@ pub const Request = struct { /// See https://tools.ietf.org/html/rfc6455 pub const WebSocket = struct { + /// Structure that builds websocket frames that are fragmented. + fragment: Fragment, + /// The websocket handshake key key: []const u8, input: *Reader, + /// Writer that is used to store large messages that the input buffer cannot handle. 
+ storage: Writer.Allocating, output: *Writer, pub const Header0 = packed struct(u8) { @@ -685,52 +696,117 @@ pub const WebSocket = struct { _, }; - pub const ReadSmallTextMessageError = error{ - ConnectionClose, + pub const ReadTextMessageError = error{ UnexpectedOpCode, MessageTooBig, MissingMaskBit, + InvalidUtf8Payload, + ControlFrameTooBig, + FragmentedControl, + UnnegociatedReservedBits, + UnexpectedFragment, ReadFailed, EndOfStream, }; - pub const SmallMessage = struct { - /// Can be text, binary, or ping. + /// Wrapper around a websocket fragmented frame. + pub const Fragment = struct { + const Error = std.mem.Allocator.Error || Writer.Error; + + /// Writer to preserve the fragmented payload. + alloc_writer: Writer.Allocating, + /// The type of message that the fragment is. + /// + /// Control fragments are not supported. + message_type: ?Opcode, + + /// Clears any allocated memory. + pub fn deinit(ws: *Fragment) void { + ws.alloc_writer.deinit(); + } + + /// Writes the payload into the buffer. + pub fn writeAll(ws: *Fragment, payload: []const u8) Error!void { + try ws.alloc_writer.ensureUnusedCapacity(payload.len); + return ws.alloc_writer.writer.writeAll(payload); + } + + /// Resets the fragment but keeps the allocated memory. + /// + /// Also reset the message type back to null. + pub fn reset(ws: *Fragment) void { + ws.alloc_writer.shrinkRetainingCapacity(0); + ws.message_type = null; + } + + /// Returns a slice of the currently written values on the buffer. + pub fn slice(ws: *Fragment) []u8 { + return ws.alloc_writer.written(); + } + + /// Returns the total amount of bytes that were written. + pub fn size(ws: Fragment) usize { + return ws.alloc_writer.writer.end; + } + }; + + pub const WebsocketMessage = struct { opcode: Opcode, data: []u8, }; - /// Reads the next message from the WebSocket stream, failing if the - /// message does not fit into the input buffer. 
The returned memory points - /// into the input buffer and is invalidated on the next read. - pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { + /// Sends close frame with the exit code and frees any allocated memory + pub fn close(ws: *WebSocket, exit_code: u16) void { + ws.writeCloseFrame(exit_code) catch {}; + ws.deinit(); + } + + /// Clears any allocated memory. + pub fn deinit(ws: *WebSocket) void { + ws.storage.deinit(); + ws.fragment.deinit(); + } + + /// Reads the next message from the WebSocket stream. + /// + /// The returned message can either point to: + /// + /// * Input buffer + /// * Storage writer buffer + /// * The fragments writer buffer + /// + /// Either of them will have their pointer invalidated on the next read message + pub fn readMessage(ws: *WebSocket) !WebsocketMessage { const in = ws.input; while (true) { const header = try in.takeArray(2); - const h0: Header0 = @bitCast(header[0]); - const h1: Header1 = @bitCast(header[1]); - switch (h0.opcode) { - .text, .binary, .pong, .ping => {}, - .connection_close => return error.ConnectionClose, - .continuation => return error.UnexpectedOpCode, - _ => return error.UnexpectedOpCode, - } + const op_head: Header0 = @bitCast(header[0]); + const payload_head: Header1 = @bitCast(header[1]); - if (!h0.fin) return error.MessageTooBig; - if (!h1.mask) return error.MissingMaskBit; + if (!payload_head.mask) + return error.MissingMaskBit; - const len: usize = switch (h1.payload_len) { + if (@bitCast(op_head.rsv1) or @bitCast(op_head.rsv2) or @bitCast(op_head.rsv3)) + return error.UnnegociatedReservedBits; + + const total = switch (payload_head.payload_len) { .len16 => try in.takeInt(u16, .big), .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig, - else => @intFromEnum(h1.payload_len), + _ => @intFromEnum(payload_head.payload_len), }; - if (len > in.buffer.len) return error.MessageTooBig; + const mask: u32 = @bitCast((try 
in.takeArray(4)).*); - const payload = try in.take(len); + const payload = blk: { + if (total < in.buffered().len) + break :blk try in.take(total); - // Skip pongs. - if (h0.opcode == .pong) continue; + try ws.storage.ensureUnusedCapacity(total); + try in.streamExact(&ws.storage.writer, total); + defer ws.storage.shrinkRetainingCapacity(0); + + break :blk ws.storage.written(); + }; // The last item may contain a partial word of unused data. const floored_len = (payload.len / 4) * 4; @@ -740,63 +816,184 @@ pub const WebSocket = struct { for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m| leftover.* ^= m; - return .{ - .opcode = h0.opcode, - .data = payload, - }; + switch (op_head.opcode) { + .text, + .binary, + => { + if (!op_head.fin) { + try ws.fragment.writeAll(payload); + ws.fragment.message_type = op_head.opcode; + + continue; + } + + if (ws.fragment.size() != 0) + return error.UnexpectedFragment; + + if (op_head.opcode == .text and !std.unicode.utf8ValidateSlice(payload)) + return error.InvalidUtf8Payload; + + return .{ + .opcode = op_head.opcode, + .data = payload, + }; + }, + .continuation, + => { + const message_type = ws.fragment.message_type orelse return error.FragmentedControl; + + if (!op_head.fin) { + try ws.fragment.writeAll(payload); + continue; + } + + try ws.fragment.writeAll(payload); + defer ws.fragment.reset(); + + const slice = ws.fragment.slice(); + + if (message_type == .text and !std.unicode.utf8ValidateSlice(slice)) + return error.InvalidUtf8Payload; + + return .{ + .opcode = message_type, + .data = slice, + }; + }, + .ping, + .pong, + .connection_close, + => { + if (total > 125 or !op_head.fin) + return error.ControlFrameTooBig; + + return .{ + .opcode = op_head.opcode, + .data = payload, + }; + }, + _ => return error.UnexpectedOpCode, + } } } - pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { - var bufs: [1][]const u8 = .{data}; - try writeMessageVecUnflushed(ws, 
&bufs, op); - try ws.output.flush(); + /// Writes to the server a close frame with a provided `exit_code`. + /// + /// For more details please see: https://www.rfc-editor.org/rfc/rfc6455#section-5.5.1 + pub fn writeCloseFrame(ws: *WebSocket, exit_code: u16) Writer.Error!void { + if (exit_code == 0) { + @branchHint(.likely); + + return ws.writeFrame("", .connection_close); + } + + var buffer: [2]u8 = undefined; + std.mem.writeInt(u16, buffer[0..2], exit_code, .big); + + return ws.writeFrame(buffer[0..], .connection_close); } - pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { + /// Writes a websocket frame directly to the socket. + /// + /// The message is unmasked according to the websocket RFC. + /// More details here: https://www.rfc-editor.org/rfc/rfc6455#section-6.1 + pub fn writeFrame(ws: *WebSocket, data: []const u8, opcode: Opcode) Writer.Error!void { var bufs: [1][]const u8 = .{data}; - try writeMessageVecUnflushed(ws, &bufs, op); + try ws.writeFrameVecUnflushed(&bufs, opcode); + + return ws.flush(); } - pub fn writeMessageVec(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { - try writeMessageVecUnflushed(ws, data, op); - try ws.output.flush(); + /// Writes a websocket frame directly to the socket. + /// + /// The message is unmasked according to the websocket RFC. + /// More details here: https://www.rfc-editor.org/rfc/rfc6455#section-6.1 + pub fn writeFrameVec(ws: *WebSocket, data: [][]const u8, opcode: Opcode) Writer.Error!void { + try ws.writeFrameVecUnflushed(data, opcode); + + return ws.flush(); + } + + /// Writes a websocket frame directly to the socket. Doesn't flush the writer's buffer. + /// + /// The fin bit is set to true on this sent frame. + /// + /// To send a fragmented message please see `writeHeaderFrameVecUnflushed` + /// and pair it with `writeBodyVecUnflushed` and don't forget to flush! + /// + /// The message is unmasked according to the websocket RFC. 
+ /// More details here: https://www.rfc-editor.org/rfc/rfc6455#section-6.1 + pub fn writeFrameVecUnflushed(ws: *WebSocket, messages: [][]const u8, opcode: Opcode) Writer.Error!void { + try ws.writeHeaderFrameVecUnflushed(messages, opcode, true); + + return ws.writeBodyVecUnflushed(messages); } - pub fn writeMessageVecUnflushed(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { - const total_len = l: { + /// Writes a websocket message directly to the socket without the header + /// + /// The message is unmasked according to the websocket RFC. + /// More details here: https://www.rfc-editor.org/rfc/rfc6455#section-6.1 + pub fn writeBodyVecUnflushed(ws: *WebSocket, data: [][]const u8) Writer.Error!void { + return ws.output.writeVecAll(data); + } + + /// Generates the websocket header frame based on the message len and the opcode provided. + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-------+-+-------------+-------------------------------+ + /// |F|R|R|R| opcode|M| Payload len | Extended payload length | + /// |I|S|S|S| (4) |A| (7) | (16/64) | + /// |N|V|V|V| |S| | (if payload len==126/127) | + /// | |1|2|3| |K| | | + /// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + /// | Extended payload length continued, if payload len == 127 | + /// + - - - - - - - - - - - - - - - +-------------------------------+ + /// | |Masking-key, if MASK set to 1 | + /// +-------------------------------+-------------------------------+ + /// | Masking-key (continued) | Payload Data | + /// +-------------------------------- - - - - - - - - - - - - - - - + + /// : Payload Data continued ... : + /// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + /// | Payload Data continued ... 
| + /// +---------------------------------------------------------------+ + pub fn writeHeaderFrameVecUnflushed(ws: *WebSocket, messages: [][]const u8, opcode: Opcode, fin_bit: bool) Writer.Error!void { + const total_len = len: { var total_len: u64 = 0; - for (data) |iovec| total_len += iovec.len; - break :l total_len; + for (messages) |iovec| total_len += iovec.len; + break :len total_len; }; - const out = ws.output; - try out.writeByte(@bitCast(@as(Header0, .{ - .opcode = op, - .fin = true, + + try ws.output.writeByte(@bitCast(@as(Header0, .{ + .opcode = opcode, + .fin = fin_bit, }))); + switch (total_len) { - 0...125 => try out.writeByte(@bitCast(@as(Header1, .{ + 0...125 => return ws.output.writeByte(@bitCast(@as(Header1, .{ .payload_len = @enumFromInt(total_len), .mask = false, }))), - 126...0xffff => { - try out.writeByte(@bitCast(@as(Header1, .{ + 126...0xFFFF => { + try ws.output.writeByte(@bitCast(@as(Header1, .{ .payload_len = .len16, .mask = false, }))); - try out.writeInt(u16, @intCast(total_len), .big); + + return ws.output.writeInt(u16, @intCast(total_len), .big); }, else => { - try out.writeByte(@bitCast(@as(Header1, .{ + try ws.output.writeByte(@bitCast(@as(Header1, .{ .payload_len = .len64, .mask = false, }))); - try out.writeInt(u64, total_len, .big); + + return ws.output.writeInt(u64, total_len, .big); }, } - try out.writeVecAll(data); } + /// Drains all of the remaining buffered data. 
pub fn flush(ws: *WebSocket) Writer.Error!void { try ws.output.flush(); } From 12a682d490a35a651ba776b250642cb95630d01f Mon Sep 17 00:00:00 2001 From: Raiden1411 <67233402+Raiden1411@users.noreply.github.com> Date: Sat, 23 Aug 2025 14:31:31 +0100 Subject: [PATCH 2/3] std: support close messages on websocket server --- lib/std/Build/WebServer.zig | 2 +- lib/std/http/Server.zig | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/lib/std/Build/WebServer.zig b/lib/std/Build/WebServer.zig index 1c84e657079d..7a373a2c715e 100644 --- a/lib/std/Build/WebServer.zig +++ b/lib/std/Build/WebServer.zig @@ -256,7 +256,7 @@ fn accept(ws: *WebServer, connection: std.net.Server.Connection) void { var web_socket = request.respondWebSocket(.{ .key = key, .allocator = ws.gpa }) catch { return log.err("failed to respond web socket: {t}", .{connection_writer.err.?}); }; - defer web_socket.close(0); + defer web_socket.close(.{ .exit_code = 0 }); ws.serveWebSocket(&web_socket) catch |err| { log.err("failed to serve websocket: {t}", .{err}); return; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 3d0e2f13d875..0ed24b91ec20 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -750,14 +750,19 @@ pub const WebSocket = struct { } }; + pub const CloseMessage = struct { + exit_code: u16, + data: []const u8 = "", + }; + pub const WebsocketMessage = struct { opcode: Opcode, data: []u8, }; /// Sends close frame with the exit code and frees any allocated memory - pub fn close(ws: *WebSocket, exit_code: u16) void { - ws.writeCloseFrame(exit_code) catch {}; + pub fn close(ws: *WebSocket, options: CloseMessage) void { + ws.writeCloseFrame(options) catch {}; ws.deinit(); } @@ -787,6 +792,7 @@ pub const WebSocket = struct { if (!payload_head.mask) return error.MissingMaskBit; + // TODO: Remove check here for op_head.rsv1 once compression is readded. 
if (@bitCast(op_head.rsv1) or @bitCast(op_head.rsv2) or @bitCast(op_head.rsv3)) return error.UnnegociatedReservedBits; @@ -821,6 +827,9 @@ pub const WebSocket = struct { .binary, => { if (!op_head.fin) { + if (ws.fragment.message_type != null) + return error.UnexpectedFragment; + try ws.fragment.writeAll(payload); ws.fragment.message_type = op_head.opcode; @@ -843,6 +852,9 @@ pub const WebSocket = struct { const message_type = ws.fragment.message_type orelse return error.FragmentedControl; if (!op_head.fin) { + if (ws.fragment.message_type == null) + return error.UnexpectedFragment; + try ws.fragment.writeAll(payload); continue; } @@ -880,17 +892,21 @@ pub const WebSocket = struct { /// Writes to the server a close frame with a provided `exit_code`. /// /// For more details please see: https://www.rfc-editor.org/rfc/rfc6455#section-5.5.1 - pub fn writeCloseFrame(ws: *WebSocket, exit_code: u16) Writer.Error!void { - if (exit_code == 0) { + pub fn writeCloseFrame(ws: *WebSocket, options: CloseMessage) Writer.Error!void { + if (options.exit_code == 0) { @branchHint(.likely); - return ws.writeFrame("", .connection_close); + return ws.writeFrame(options.data, .connection_close); } var buffer: [2]u8 = undefined; - std.mem.writeInt(u16, buffer[0..2], exit_code, .big); + std.mem.writeInt(u16, buffer[0..2], options.exit_code, .big); + + var bufs: [2][]const u8 = .{ buffer[0..], options.data }; + try ws.writeHeaderFrameVecUnflushed(&bufs, .connection_close, true); + try ws.writeBodyVecUnflushed(&bufs); - return ws.writeFrame(buffer[0..], .connection_close); + return ws.flush(); } /// Writes a websocket frame directly to the socket. 
From 2968e8ef1c5a8499a3acfdc3394292b27cd2a88a Mon Sep 17 00:00:00 2001 From: Raiden1411 <67233402+Raiden1411@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:53:22 +0100 Subject: [PATCH 3/3] std: update logic on payload read --- lib/std/http/Server.zig | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 0ed24b91ec20..389064367bfd 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -792,7 +792,8 @@ pub const WebSocket = struct { if (!payload_head.mask) return error.MissingMaskBit; - // TODO: Remove check here for op_head.rsv1 once compression is readded. + // TODO: Remove check here for op_head.rsv1 if compression + // is added in the future if (@bitCast(op_head.rsv1) or @bitCast(op_head.rsv2) or @bitCast(op_head.rsv3)) return error.UnnegociatedReservedBits; @@ -804,14 +805,15 @@ pub const WebSocket = struct { const mask: u32 = @bitCast((try in.takeArray(4)).*); const payload = blk: { - if (total < in.buffered().len) - break :blk try in.take(total); + if (total > in.buffer.len) { + try ws.storage.ensureUnusedCapacity(total); + try in.streamExact(&ws.storage.writer, total); + defer ws.storage.shrinkRetainingCapacity(0); - try ws.storage.ensureUnusedCapacity(total); - try in.streamExact(&ws.storage.writer, total); - defer ws.storage.shrinkRetainingCapacity(0); + break :blk ws.storage.written(); + } - break :blk ws.storage.written(); + break :blk try in.take(total); }; // The last item may contain a partial word of unused data. 
@@ -827,9 +829,6 @@ pub const WebSocket = struct { .binary, => { if (!op_head.fin) { - if (ws.fragment.message_type != null) - return error.UnexpectedFragment; - try ws.fragment.writeAll(payload); ws.fragment.message_type = op_head.opcode; @@ -852,9 +851,6 @@ pub const WebSocket = struct { const message_type = ws.fragment.message_type orelse return error.FragmentedControl; if (!op_head.fin) { - if (ws.fragment.message_type == null) - return error.UnexpectedFragment; - try ws.fragment.writeAll(payload); continue; }