diff --git a/memhttp.go b/memhttp.go index c1cd520..ac02ee0 100644 --- a/memhttp.go +++ b/memhttp.go @@ -23,7 +23,8 @@ type Server struct { certificate *x509.Certificate // for client url string disableHTTP2 bool - serveErr chan error + serveWG sync.WaitGroup + serveErr error cleanupContext func() (context.Context, context.CancelFunc) } @@ -62,24 +63,26 @@ func New(handler http.Handler, opts ...Option) (*Server, error) { lis = tls.NewListener(mlis, server.TLSConfig) } - serveErr := make(chan error, 1) - go func() { - serveErr <- server.Serve(lis) - }() - scheme := "https://" if cfg.DisableTLS { scheme = "http://" } - return &Server{ + s := &Server{ server: server, listener: mlis, certificate: clientCert, url: scheme + mlis.Addr().String(), disableHTTP2: cfg.DisableHTTP2, - serveErr: serveErr, cleanupContext: cfg.CleanupContext, - }, nil + } + + s.serveWG.Add(1) + go func() { + defer s.serveWG.Done() + s.serveErr = s.server.Serve(lis) + }() + + return s, nil } // Transport returns an [http.Transport] configured to use in-memory pipes @@ -153,8 +156,10 @@ func (s *Server) RegisterOnShutdown(f func()) { } func (s *Server) listenErr() error { - if err := <-s.serveErr; err != nil && !errors.Is(err, http.ErrServerClosed) { - return err + s.serveWG.Wait() + + if !errors.Is(s.serveErr, http.ErrServerClosed) { + return s.serveErr } return nil } diff --git a/memhttp_test.go b/memhttp_test.go index fc204e5..cea58a7 100644 --- a/memhttp_test.go +++ b/memhttp_test.go @@ -84,6 +84,20 @@ func TestRegisterOnShutdown(t *testing.T) { } } +func TestClose(t *testing.T) { + t.Parallel() + srv, err := memhttp.New(&greeter{}) + attest.Ok(t, err) + err = srv.Close() + attest.Ok(t, err) + err = srv.Close() + attest.Ok(t, err) + err = srv.Shutdown(context.Background()) + attest.Ok(t, err) + err = srv.Cleanup() + attest.Ok(t, err) +} + func Example() { hello := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, world!")