diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 45173f0..f5f261a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: go-version: ['1.24'] - os: [ubuntu-latest, windows-latest, darwin-latest] + os: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 diff --git a/TODO.md b/TODO.md index 38f36d0..571dfc0 100644 --- a/TODO.md +++ b/TODO.md @@ -16,6 +16,7 @@ sooner rather than later - [ ] chunked transfer responses (with channels) - [ ] write cache headers on file handler responses - [ ] CORS headers +- [ ] support CORS preflight requests (OPTIONS) - [ ] custom reader type for request reading (alternative to bufio, greedy reader that reads until end of http request) later: diff --git a/common/ascii/ascii.go b/common/ascii/ascii.go new file mode 100644 index 0000000..ee73b21 --- /dev/null +++ b/common/ascii/ascii.go @@ -0,0 +1,139 @@ +package ascii + +const ( + NUL = byte(iota) + SOH + STX + ETX + EOT + ENQ + ACK + BEL + BS + HT + LF + VT + FF + CR + SO + SI + DLE + DC1 + DC2 + DC3 + DC4 + NAK + SYN + ETB + CAN + EM + SUB + ESC + FS // 0x1C File Separator + GS // 0x1D Group Separator + RS // 0x1E Record Separator + US // 0x1F Unit Separator + + SPACE // 0x20 + EXCL // 0x21 ! + QUOTE // 0x22 " + HASH // 0x23 # + DOLLAR // 0x24 $ + PERCENT // 0x25 % + AMPERSAND // 0x26 & + APOSTROPHE // 0x27 ' + LPAREN // 0x28 ( + RPAREN // 0x29 ) + ASTERISK // 0x2A * + PLUS // 0x2B + + COMMA // 0x2C , + MINUS // 0x2D - + DOT // 0x2E . + SLASH // 0x2F / + + ZERO // 0x30 + ONE // 0x31 + TWO // 0x32 + THREE // 0x33 + FOUR // 0x34 + FIVE // 0x35 + SIX // 0x36 + SEVEN // 0x37 + EIGHT // 0x38 + NINE // 0x39 + + COLON // 0x3A : + SEMICOLON // 0x3B ; + LT // 0x3C < + EQ // 0x3D = + GT // 0x3E > + QUESTION // 0x3F ? + AT // 0x40 @ + + A // 0x41 + B // 0x42 + C // 0x43 + D // 0x44 + E // 0x45 + F // 0x46 + G // 0x47 + H // 0x48 + I // 0x49 + J // 0x4A + K // 0x4B + L // 0x4C + M // 0x4D + N // 0x4E + O // 0x4F + P // 0x50 + Q // 0x51 + R // 0x52 + S // 0x53 + T // 0x54 + U // 0x55 + V // 0x56 + W // 0x57 + X // 0x58 + Y // 0x59 + Z // 0x5A + + LBRACKET // 0x5B [ + BACKSLASH // 0x5C \ + RBRACKET // 0x5D ] + CARET // 0x5E ^ + UNDERSCORE // 0x5F _ + GRAVE // 0x60 ` + + LC_A // 0x61 + LC_B // 0x62 + LC_C // 0x63 + LC_D // 0x64 + LC_E // 0x65 + LC_F // 0x66 + LC_G // 0x67 + LC_H // 0x68 + LC_I // 0x69 + LC_J // 0x6A + LC_K // 0x6B + LC_L // 0x6C + LC_M // 0x6D + LC_N // 0x6E + LC_O // 0x6F + LC_P // 0x70 + LC_Q // 0x71 + LC_R // 0x72 + LC_S // 0x73 + LC_T // 0x74 + LC_U // 0x75 + LC_V // 0x76 + LC_W // 0x77 + LC_X // 0x78 + LC_Y // 0x79 + LC_Z // 0x7A + + LBRACE // 0x7B { + PIPE // 0x7C | + RBRACE // 0x7D } + TILDE // 0x7E ~ + DEL // 0x7F +) diff --git a/handlers/brotli.go b/handlers/brotli.go index 39228bc..c8ebcea 100644 --- a/handlers/brotli.go +++ b/handlers/brotli.go @@ -22,25 +22,67 @@ func (b brotliHandler) HandleRequest(ctx http.Context) error { if !reqAcceptsBrotli(ctx.Request) { return nil } - bbuf, err := castBody(ctx.Response.Body) - if err != nil { - return err - } - newBuf, err := b.compressBody(bbuf) - if err != nil { - return err + + if c, ok := ctx.Response.Body.(chan http.StreamedResponseChunk); ok { + err := b.handleChannel(ctx, c) + if err != nil { + return err + } + } else { //NOT a channel, do normal stuff + bbuf, err := castBody(ctx.Response.Body) + if err != nil { + return err + } + newBuf, err := b.compressBody(bbuf) + if err != nil { + ctx.Response.AddHeader(http.Header{ + Name: "Content-Length", + Value: strconv.Itoa(len(newBuf)), + }) + return err + } + //assign body to response + ctx.Response.Body = newBuf } - //assign body to response and set compression header + //set compression header ctx.Response.AddHeader(http.Header{ Name: "Content-Encoding", Value: "br", }) - ctx.Response.AddHeader(http.Header{ - Name: "Content-Length", - Value: strconv.Itoa(len(newBuf)), + return nil +} + +func (b brotliHandler) handleChannel(ctx http.Context, c chan http.StreamedResponseChunk) error { + tChan := make(chan http.StreamedResponseChunk, 1) + ctx.Response.Body = tChan + var newBuf bytes.Buffer + writer := brotli.NewWriterOptions(&newBuf, brotli.WriterOptions{ + Quality: b.quality, + LGWin: 0, }) - ctx.Response.Body = newBuf + go func() { + defer close(tChan) + for chunk := range c { + if chunk.Err != nil { + tChan <- chunk + return + } + _, err := writer.Write(chunk.Data) + if err != nil { + tChan <- http.StreamedResponseChunk{Err: err} + return + } + } + err := writer.Close() + chunk := http.StreamedResponseChunk{} + if err != nil { + chunk.Err = err + } else { + chunk.Data = newBuf.Bytes() + } + tChan <- chunk + }() return nil } diff --git a/handlers/response_headers.go b/handlers/response_headers.go index ac81402..5aefa15 100644 --- a/handlers/response_headers.go +++ b/handlers/response_headers.go @@ -16,29 +16,41 @@ var ServerHeader = http.Header{ Value: fmt.Sprintf("%s/%s", SERVER_NAME, VERSION), } +var timeFunc = time.Now + func ResponseHeadersHandler(ctx http.Context) error { //Server ctx.Response.AddHeader(ServerHeader) //Date ctx.Response.AddHeader(http.Header{ Name: "Date", - Value: common.ToHttpDateFormat(time.Now()), + Value: common.ToHttpDateFormat(timeFunc()), }) - if !ctx.Response.Headers.HasHeader("Content-Length") { - var length int - switch v := ctx.Response.Body.(type) { - case string: - length = len(v) - case []byte: - length = len(v) - default: - //TODO: figure out what to do - return nil - } + _, bodyIsChannel := ctx.Response.Body.(chan http.StreamedResponseChunk) + if bodyIsChannel { + //delete content-length, add transfer-encoding: chunked instead + delete(ctx.Response.Headers, "Content-Length") ctx.Response.AddHeader(http.Header{ - Name: "Content-Length", - Value: strconv.Itoa(length), + Name: "Transfer-Encoding", + Value: "chunked", }) + } else { + if !ctx.Response.Headers.HasHeader("Content-Length") { + var length int + switch v := ctx.Response.Body.(type) { + case string: + length = len(v) + case []byte: + length = len(v) + default: + //TODO: figure out what to do + return nil + } + ctx.Response.AddHeader(http.Header{ + Name: "Content-Length", + Value: strconv.Itoa(length), + }) + } } tryWriteConnectionHeader(ctx) diff --git a/handlers/time_helper.go b/handlers/time_helper.go new file mode 100644 index 0000000..9205a74 --- /dev/null +++ b/handlers/time_helper.go @@ -0,0 +1,11 @@ +//go:build test + +package handlers + +import ( + "time" +) + +func SetTimeFunc(f func() time.Time) { + timeFunc = f +} diff --git a/http/headers.go b/http/headers.go index cc71c35..0a3b2b8 100644 --- a/http/headers.go +++ b/http/headers.go @@ -1,5 +1,9 @@ package http +import ( + "sort" +) + type Header struct { Name string Value string @@ -15,3 +19,16 @@ func (m Headers) HasHeader(key string) bool { } return false } + +// Sorted returns Headers as a slice of Header, where the headers are sorted alphanumerically ascending by their Name property +func (m Headers) Sorted() []Header { + headers := make([]Header, 0, len(m)) + for _, header := range m { + headers = append(headers, header) + } + // Sort by Name ascending + sort.Slice(headers, func(i, j int) bool { + return headers[i].Name < headers[j].Name + }) + return headers +} diff --git a/http/response.go b/http/response.go index 188993d..19f9221 100644 --- a/http/response.go +++ b/http/response.go @@ -4,8 +4,11 @@ import ( "bufio" "bytes" "fmt" + "gophttp/common/ascii" "net" + "strconv" "strings" + "time" ) type Response struct { @@ -32,13 +35,7 @@ func (r Response) WriteToConn(conn net.Conn) error { return err } //write all headers - for _, header := range r.Headers { - _, err = w.WriteString(fmt.Sprintf("%s: %s\n", header.Name, strings.TrimRight(header.Value, "\n"))) - if err != nil { - return err - } - } - _, err = w.WriteString("\n") + err = r.writeHeaders(w) if err != nil { return err } @@ -64,6 +61,28 @@ func (r Response) WriteToConn(conn net.Conn) error { if err != nil { return err } + } else if c, ok := r.Body.(chan StreamedResponseChunk); ok { + //if we get a byte slice channel, start a loop where we read from said channel until it closes + //we block here and do not create another goroutine because we need to wait until we fully wrote our response + //before moving on to the next request in the TCP connection + for { + select { + case chunk, more := <-c: + if !more { + err = handleChunk(StreamedResponseChunk{Data: make([]byte, 0)}, w) + if err != nil { + return err + } + return w.Flush() + } + err = handleChunk(chunk, w) + if err != nil { + return err + } + case <-time.After(15 * time.Second): //TODO: make configurable + return fmt.Errorf("read timeout on body channel") + } + } } else { //log and return err (500) return fmt.Errorf("%v: %T", ErrUnknownBodyType, r.Body) @@ -71,3 +90,46 @@ func (r Response) WriteToConn(conn net.Conn) error { return w.Flush() } + +func (r Response) writeHeaders(w *bufio.Writer) error { + var err error + for _, header := range r.Headers.Sorted() { + _, err = w.WriteString(fmt.Sprintf("%s: %s\n", header.Name, strings.TrimRight(header.Value, "\n"))) + if err != nil { + return err + } + } + _, err = w.WriteString("\n") + return err +} + +func handleChunk(chunk StreamedResponseChunk, w *bufio.Writer) error { + if chunk.Err != nil { + return chunk.Err + } + //write length of chunk + chunkLen := len(chunk.Data) + chunkLenHex := strconv.FormatInt(int64(chunkLen), 16) + _, err := w.Write([]byte(chunkLenHex)) + if err != nil { + return err + } + //write CR+LF + err = writeCRLF(w) + if err != nil { + return err + } + //write chunk + _, err = w.Write(chunk.Data) + if err != nil { + return err + } + //write CR+LF + err = writeCRLF(w) + return err +} + +func writeCRLF(w *bufio.Writer) error { + _, err := w.Write([]byte{ascii.CR, ascii.LF}) + return err +} diff --git a/http/streamed_response_chunk.go b/http/streamed_response_chunk.go new file mode 100644 index 0000000..a36af62 --- /dev/null +++ b/http/streamed_response_chunk.go @@ -0,0 +1,6 @@ +package http + +type StreamedResponseChunk struct { + Data []byte + Err error +} diff --git a/main.go b/main.go index da8bb07..4606636 100644 --- a/main.go +++ b/main.go @@ -29,12 +29,12 @@ func main() { //instantiate server serv := server.NewHttpServer(4488) - err = serv.AddRoutes(pwd) + err = serv.AddFileRoutes(pwd) if err != nil { panic(err) } err = serv.StartServing(ctx) if err != nil { - slog.Error("error in server thread", err) + slog.Error("error in server thread", "err", err.Error()) } } diff --git a/server/server.go b/server/server.go index 8583c7c..bbc160a 100644 --- a/server/server.go +++ b/server/server.go @@ -41,8 +41,8 @@ func (s *HttpServer) nextReqIndex() uint64 { return s.reqIndex } -// AddRoutes searches for all files and directories under path and adds a handler for each of them to the server -func (s *HttpServer) AddRoutes(path string) error { +// AddFileRoutes searches for all files and directories under path and adds a handler for each of them to the server +func (s *HttpServer) AddFileRoutes(path string) error { files, err := common.ListFilesRecursive(path) if err != nil { panic(err) @@ -100,6 +100,13 @@ func (s *HttpServer) addDirRoute(dir string) error { return err } +func (s *HttpServer) AddHandler(route string, method http.Method, handler handlers.Handler) error { + if route == "" { + return fmt.Errorf("invalid route: can't be empty string") + } + return s.insertRoute(route, method, handler) +} + func (s *HttpServer) StartServing(ctx context.Context) error { sock, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) tcpSock := sock.(*net.TCPListener) @@ -110,7 +117,7 @@ func (s *HttpServer) StartServing(ctx context.Context) error { defer func(sock net.Listener) { err := sock.Close() if err != nil { - slog.Error("error closing socket", err) + slog.Error("error closing socket", "err", err.Error()) } }(sock) @@ -129,7 +136,7 @@ func (s *HttpServer) StartServing(ctx context.Context) error { func (s *HttpServer) connectLoop(tcpSock *net.TCPListener) { err := tcpSock.SetDeadline(time.Now().Add(1 * time.Second)) if err != nil { - slog.Error("error setting socket deadline", err) + slog.Error("error setting socket deadline", "err", err.Error()) return } conn, err := tcpSock.Accept() @@ -139,7 +146,7 @@ func (s *HttpServer) connectLoop(tcpSock *net.TCPListener) { //ignore timeout errors as they are expected return } - slog.Error("failed accepting tcp socket connection", err) + slog.Error("failed accepting tcp socket connection", "err", err.Error()) return } go s.handleConnection(conn) @@ -150,7 +157,7 @@ func (s *HttpServer) handleConnection(conn net.Conn) { defer func(conn net.Conn) { err := conn.Close() if err != nil { - slog.Error("failed closing socket", err) + slog.Error("failed closing socket", "err", err.Error()) } }(conn) diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..831d53e --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,358 @@ +//go:build test + +package server_test + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "strings" + "testing" + "time" + + "gophttp/handlers" + "gophttp/http" + "gophttp/server" +) + +func TestCustomAddedHandlerIsCalled(t *testing.T) { + // Start server on a random port + port := 8089 // Use a test port unlikely to be in use + httpServer := server.NewHttpServer(port) + + // Register a custom handler + const testPath = "/test" + const expectedBody = "Hello, test!" + err := httpServer.AddHandler(testPath, http.GET, handlers.HandlerFunc(func(ctx http.Context) error { + ctx.Response.Status = http.StatusOK + ctx.Response.Body = expectedBody + ctx.Response.AddHeader(http.Header{Name: "Content-Type", Value: "text/plain"}) + return nil + })) + if err != nil { + t.Fatalf("failed setting up handler: %v", err) + } + ctx, cfunc := context.WithCancel(context.Background()) + servClosed := make(chan bool, 1) + + // Start server in background + go func() { + defer func() { servClosed <- true }() + err := httpServer.StartServing(ctx) + if err != nil { + t.Errorf("failed to serve: %v", err) + return + } + }() + // Wait for server to start + time.Sleep(200 * time.Millisecond) + + // Connect to server + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + // Send HTTP GET request + req := fmt.Sprintf("GET %s HTTP/1.1\r\nHost: localhost\r\n\r\n", testPath) + _, err = conn.Write([]byte(req)) + if err != nil { + t.Fatalf("failed to write request: %v", err) + } + + // Read response + reader := bufio.NewReader(conn) + var response strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + response.WriteString(line) + if line == "\r\n" || line == "\n" { + break // End of headers + } + } + // Read body + body, _ := reader.ReadString('\n') + response.WriteString(body) + + if !strings.Contains(response.String(), expectedBody) { + t.Errorf("expected body %q in response, got: %q", expectedBody, response.String()) + } + cfunc() + //block until server is dead (this is to prevent having multiple tests that create servers running at the same time) + //((assuming that all tests run sequentially not parallel, which seems to be the case)) + _ = <-servClosed +} + +func TestStreamedResponseWithDelay(t *testing.T) { + port := 8090 // Use a different test port + httpServer := server.NewHttpServer(port) + + const testPath = "/stream" + seg1 := []byte("segment1-") + seg2 := []byte("segment2!") + + err := httpServer.AddHandler(testPath, http.GET, handlers.HandlerFunc(func(ctx http.Context) error { + ctx.Response.Status = http.StatusOK + ctx.Response.AddHeader(http.Header{Name: "Content-Type", Value: "text/plain"}) + ch := make(chan http.StreamedResponseChunk) + ctx.Response.Body = ch + go func() { + ch <- http.StreamedResponseChunk{Data: seg1} + time.Sleep(1 * time.Second) + ch <- http.StreamedResponseChunk{Data: seg2} + close(ch) + }() + return nil + })) + if err != nil { + t.Fatalf("failed setting up handler: %v", err) + } + ctx, cfunc := context.WithCancel(context.Background()) + servClosed := make(chan bool, 1) + + go func() { + defer func() { servClosed <- true }() + err := httpServer.StartServing(ctx) + if err != nil { + t.Errorf("failed to serve: %v", err) + return + } + }() + time.Sleep(200 * time.Millisecond) + + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + req := fmt.Sprintf("GET %s HTTP/1.1\r\nHost: localhost\r\n\r\n", testPath) + _, err = conn.Write([]byte(req)) + if err != nil { + t.Fatalf("failed to write request: %v", err) + } + + reader := bufio.NewReader(conn) + var response strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + response.WriteString(line) + if line == "\r\n" || line == "\n" { + break + } + } + // Read first segment + body1 := make([]byte, len(seg1)+4) + _, err = reader.Read(body1) + if err != nil { + t.Fatalf("failed to read first segment: %v", err) + } + // Wait for the second segment (should be delayed) + body2 := make([]byte, len(seg2)+4) + _, err = reader.Read(body2) + if err != nil { + t.Fatalf("failed to read second segment: %v", err) + } + response.Write(body1) + response.Write(body2) + + body := response.String() + + if !strings.Contains(body, "9\r\n"+string(seg1)) || !strings.Contains(body, "9\r\n"+string(seg2)) { + t.Errorf("expected streamed segments in response, got: %q", body) + } + + if !strings.Contains(body, "Transfer-Encoding: chunked") { + t.Error("expected 'Transfer-Encoding: chunked' header but didn't find it") + } + cfunc() + _ = <-servClosed +} + +func TestStreamedResponseWithDelayAndBrotliAccept(t *testing.T) { + timeFunc := func() time.Time { + return time.Date(2025, 7, 13, 11, 57, 50, 0, time.UTC) + } + + handlers.SetTimeFunc(timeFunc) + + port := 8091 // Use a different test port + httpServer := server.NewHttpServer(port) + + const testPath = "/stream-brotli" + seg1 := bytes.Repeat([]byte("L"), 256) + seg2 := []byte("\n\n") + seg3 := bytes.Repeat([]byte("F"), 256) + + handlerFn := handlers.HandlerFunc(func(ctx http.Context) error { + ctx.Response.Status = http.StatusOK + ctx.Response.AddHeader(http.Header{Name: "Content-Type", Value: "text/plain"}) + ch := make(chan http.StreamedResponseChunk) + ctx.Response.Body = ch + go func() { + defer close(ch) + ch <- http.StreamedResponseChunk{Data: seg1} + time.Sleep(100 * time.Millisecond) + ch <- http.StreamedResponseChunk{Data: seg2} + time.Sleep(100 * time.Millisecond) + ch <- http.StreamedResponseChunk{Data: seg3} + }() + return nil + }) + compressionHandler := handlers.NewCompressionHandler() + handler := handlers.ComposeHandlers(handlerFn, compressionHandler) + + err := httpServer.AddHandler(testPath, http.GET, handler) + if err != nil { + t.Fatalf("failed setting up handler: %v", err) + } + ctx, cfunc := context.WithCancel(context.Background()) + servClosed := make(chan bool, 1) + + go func() { + defer func() { servClosed <- true }() + err := httpServer.StartServing(ctx) + if err != nil { + t.Errorf("failed to serve: %v", err) + return + } + }() + time.Sleep(200 * time.Millisecond) + + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + req := fmt.Sprintf("GET %s HTTP/1.1\r\nHost: localhost\r\nAccept-Encoding: br\r\n\r\n", testPath) + _, err = conn.Write([]byte(req)) + if err != nil { + t.Fatalf("failed to write request: %v", err) + } + + reader := bufio.NewReader(conn) + //brotli buffers all segments before compressing, meaning we only get one big compressed chunk + body1 := make([]byte, 203) + _, err = reader.Read(body1) + if err != nil { + t.Fatalf("failed to read first segment: %v", err) + } + expectedBytes := []byte{ + 0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, 0x20, 0x32, 0x30, 0x30, 0x20, 0x4f, 0x4b, 0x0a, + 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x6b, 0x65, 0x65, 0x70, + 0x2d, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x0a, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x45, + 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0x20, 0x62, 0x72, 0x0a, 0x43, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x54, 0x79, 0x70, 0x65, 0x3a, 0x20, 0x74, 0x65, 0x78, 0x74, 0x2f, 0x70, + 0x6c, 0x61, 0x69, 0x6e, 0x0a, 0x44, 0x61, 0x74, 0x65, 0x3a, 0x20, 0x53, 0x75, 0x6e, 0x2c, 0x20, + 0x31, 0x33, 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x32, 0x35, 0x20, 0x31, 0x31, 0x3a, 0x35, + 0x37, 0x3a, 0x35, 0x30, 0x20, 0x47, 0x4d, 0x54, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x3a, + 0x20, 0x67, 0x6f, 0x70, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x30, 0x2e, 0x31, 0x0a, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0x20, + 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x0a, 0x0a, 0x31, 0x31, 0x0d, 0x0a, 0x1b, 0x01, 0x02, + 0x00, 0x24, 0x15, 0x8c, 0x98, 0x6a, 0xb1, 0xcd, 0x0a, 0x40, 0xe4, 0x3e, 0x47, 0x00, 0x0d, 0x0a, + 0x30, 0x0d, 0x0a, 0x0d, 0x0a, + } + + if len(body1) < len(expectedBytes) { + t.Errorf("expected len %d but body is only %d long", len(expectedBytes), len(body1)) + } + for i, ex := range expectedBytes { + if ex != body1[i] { + t.Errorf("expected byte %d but got %d at idx %d", ex, body1[i], i) + } + } + + cfunc() + _ = <-servClosed +} + +func TestKeepAliveSupport(t *testing.T) { + port := 8092 // Use a different test port + httpServer := server.NewHttpServer(port) + + const testPath = "/test" + const expectedBody = "Hello, test!" + err := httpServer.AddHandler(testPath, http.GET, handlers.HandlerFunc(func(ctx http.Context) error { + ctx.Response.Status = http.StatusOK + ctx.Response.Body = expectedBody + ctx.Response.AddHeader(http.Header{Name: "Content-Type", Value: "text/plain"}) + return nil + })) + if err != nil { + t.Fatalf("failed setting up handler: %v", err) + } + ctx, cfunc := context.WithCancel(context.Background()) + servClosed := make(chan bool, 1) + + // Start server in background + go func() { + defer func() { servClosed <- true }() + err := httpServer.StartServing(ctx) + if err != nil { + t.Errorf("failed to serve: %v", err) + return + } + }() + // Wait for server to start + time.Sleep(200 * time.Millisecond) + + conn, err := net.Dial("tcp", "localhost:8092") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + // First request with keep-alive + request1 := "GET /test HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Connection: keep-alive\r\n\r\n" + + // Second request with close to terminate connection + request2 := "GET /test HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Connection: close\r\n\r\n" + + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Write([]byte(request1)) + if err != nil { + t.Fatalf("Failed to write first request: %v", err) + } + + time.Sleep(500 * time.Millisecond) // brief pause + + _, err = conn.Write([]byte(request2)) + if err != nil { + t.Fatalf("Failed to write second request: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + respBuf := new(strings.Builder) + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + respBuf.WriteString(scanner.Text()) + respBuf.WriteString("\n") + } + if err := scanner.Err(); err != nil { + t.Fatalf("Error reading response: %v", err) + } + + response := respBuf.String() + count := strings.Count(response, "HTTP/1.1") + if count < 2 { + t.Fatalf("expected 2 HTTP responses, got %d\nFull response:\n%s", count, response) + } + + t.Logf("✅ Keep-Alive is working correctly: received %d responses", count) + cfunc() + _ = <-servClosed +}