Skip to content
92 changes: 60 additions & 32 deletions cmd/dmsgweb/commands/dmsgweb.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/chen3feng/safecast"
"github.com/confiant-inc/go-socks5"
Expand Down Expand Up @@ -288,28 +288,28 @@ dmsgweb conf file detected: ` + dwcfg
if len(resolveDmsgAddr) == 0 && len(webPort) == 1 {
if len(rawTCP) > 0 && rawTCP[0] {
dlog.Debug("proxyTCPConn(-1)")
proxyTCPConn(-1)
proxyTCPConn(ctx, -1)
} else {
dlog.Debug("proxyHTTPConn(-1)")
proxyHTTPConn(-1)
proxyHTTPConn(ctx, -1)
}
} else {
for i := range resolveDmsgAddr {
wg.Add(1)
if rawTCP[i] {
dlog.Debug("proxyTCPConn(" + fmt.Sprintf("%v", i) + ")")
go proxyTCPConn(i)
go proxyTCPConn(ctx, i)
} else {
dlog.Debug("proxyHTTPConn(" + fmt.Sprintf("%v", i) + ")")
go proxyHTTPConn(i)
go proxyHTTPConn(ctx, i)
}
}
}
wg.Wait()
},
}

func proxyTCPConn(n int) {
func proxyTCPConn(ctx context.Context, n int) { //nolint:unparam
var thiswebport uint
if n == -1 {
thiswebport = webPort[0]
Expand Down Expand Up @@ -337,49 +337,56 @@ func proxyTCPConn(n int) {
defer ioutil.CloseQuietly(conn, dlog)
dp, ok := safecast.To[uint16](dmsgPorts[n])
if !ok {
dlog.Fatal("uint16 overflow when converting dmsg port")
dlog.WithError(fmt.Errorf("uint16 overflow for port %v", dmsgPorts[n])).
Warn("Failed to convert dmsg port")
return
}
dlog.Debug(fmt.Sprintf("Dialing %v:%v", dialPK[n].String(), dp))
dmsgConn, err := dmsgC.DialStream(context.Background(), dmsg.Addr{PK: dialPK[n], Port: dp}) //nolint
dmsgConn, err := dmsgC.DialStream(ctx, dmsg.Addr{PK: dialPK[n], Port: dp})
if err != nil {
dlog.WithError(err).Warn(fmt.Sprintf("Failed to dial dmsg address %v port %v", dialPK[n].String(), dmsgPorts[n]))
return
}

defer ioutil.CloseQuietly(dmsgConn, dlog)

var wg sync.WaitGroup
wg.Add(2)

done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
_, err := io.Copy(dmsgConn, conn)
if err != nil {
dlog.WithError(err).Warn("Error on io.Copy(dmsgConn, conn)")
dlog.WithError(err).Debug("io.Copy(dmsgConn, conn) ended")
}
}()

go func() {
defer wg.Done()
_, err := io.Copy(conn, dmsgConn)
if err != nil {
dlog.WithError(err).Warn("Error on io.Copy(conn, dmsgConn)")
}
}()
_, err = io.Copy(conn, dmsgConn)
if err != nil {
dlog.WithError(err).Debug("io.Copy(conn, dmsgConn) ended")
}

wg.Wait()
// Close both to unblock the goroutine's io.Copy.
if err := conn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing client conn")
}
if err := dmsgConn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing dmsg conn")
}
<-done
}(conn, n, dmsgC)
}
}

func proxyHTTPConn(n int) {
func proxyHTTPConn(ctx context.Context, n int) { //nolint:unparam
r := gin.New()

r.Use(gin.Recovery())

r.Use(loggingMiddleware())

r.Any("/*path", func(c *gin.Context) {
// Limit request body to 10MB to prevent resource exhaustion.
const maxBodySize = 10 << 20
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBodySize)

var urlStr string
if n > -1 {
urlStr = fmt.Sprintf("dmsg://%s%s", resolveDmsgAddr[n], c.Param("path"))
Expand All @@ -401,7 +408,7 @@ func proxyHTTPConn(n int) {
}

dlog.Debug(fmt.Sprintf("Proxying request: %s %s", c.Request.Method, urlStr))
req, err := http.NewRequest(c.Request.Method, urlStr, c.Request.Body)
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, urlStr, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create HTTP request")
dlog.WithError(err).Warn("Failed to create HTTP request")
Expand Down Expand Up @@ -430,23 +437,44 @@ func proxyHTTPConn(n int) {

c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
c.String(http.StatusInternalServerError, "Failed to copy response body")
// Status header is already written; cannot override with 500.
// Just log the error.
dlog.WithError(err).Warn("Failed to copy response body")
}
})

var thiswebport uint
if n == -1 {
thiswebport = webPort[0]
} else {
thiswebport = webPort[n]
}

srv := &http.Server{
Addr: fmt.Sprintf(":%v", thiswebport),
Handler: r,
ReadHeaderTimeout: 5 * time.Second,
}

wg.Add(1)
go func() {
defer wg.Done()
var thiswebport uint
if n == -1 {
thiswebport = webPort[0]
} else {
thiswebport = webPort[n]
}
dlog.Debug(fmt.Sprintf("Serving http on: http://127.0.0.1:%v", thiswebport))
r.Run(":" + fmt.Sprintf("%v", thiswebport)) //nolint
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
dlog.WithError(err).Error("HTTP server error")
}
dlog.Debug(fmt.Sprintf("Stopped serving http on: http://127.0.0.1:%v", thiswebport))
}()

// Graceful shutdown on context cancellation.
go func() { //nolint:gosec // G118: context.Background is intentional — shutdown must outlive parent ctx
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
dlog.WithError(err).Warn("HTTP server shutdown error")
}
}()
}

const envfileLinux = //nolint unused
Expand Down
57 changes: 44 additions & 13 deletions cmd/dmsgweb/commands/dmsgwebsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ func proxyHTTPConnections(ctx context.Context, localPort uint, listener net.List
}
authRoute.Any("/*path", func(c *gin.Context) {
targetURL := fmt.Sprintf("http://127.0.0.1:%d%s?%s", localPort, c.Request.URL.Path, c.Request.URL.RawQuery)
parsed, err := url.Parse(targetURL)
if err != nil {
dlog.Errorf("failed to parse target URL %q: %v", targetURL, err)
c.String(http.StatusInternalServerError, "Bad target URL")
return
}
proxy := httputil.ReverseProxy{Director: func(req *http.Request) {
parsed, err := url.Parse(targetURL)
if err != nil {
dlog.Errorf("failed to parse target URL %q: %v", targetURL, err)
return
}
req.URL = parsed
req.Host = req.URL.Host
}}
Expand Down Expand Up @@ -211,12 +212,16 @@ func proxyHTTPConnections(ctx context.Context, localPort uint, listener net.List
}
}

// maxTCPConns is the maximum number of concurrent TCP proxy connections.
const maxTCPConns = 256

func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Listener) {
// To track active connections for cleanup
var connWg sync.WaitGroup
connChan := make(chan net.Conn)
activeConns := make(map[net.Conn]struct{})
connMutex := &sync.Mutex{} // Protect access to activeConns
sem := make(chan struct{}, maxTCPConns)

// Goroutine to accept new connections
go func() {
Expand All @@ -241,11 +246,15 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste
select {
case <-ctx.Done():
dlog.Info("Shutting down TCP proxy connections...")
listener.Close() //nolint
if err := listener.Close(); err != nil {
dlog.WithError(err).Debug("Error closing TCP listener")
}

connMutex.Lock()
for conn := range activeConns {
conn.Close() //nolint
if err := conn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing active connection")
}
}
connMutex.Unlock()

Expand All @@ -257,14 +266,30 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste
return
}

// Limit concurrent connections.
select {
case sem <- struct{}{}:
default:
dlog.Warn("Max TCP connections reached, rejecting connection")
if err := conn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing rejected connection")
}
continue
}

connMutex.Lock()
activeConns[conn] = struct{}{}
connMutex.Unlock()

connWg.Add(1)
go func(dmsgConn net.Conn) {
defer func() { <-sem }()
defer connWg.Done()
defer dmsgConn.Close() //nolint
defer func() {
if err := dmsgConn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing dmsg connection")
}
}()

localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
if err != nil {
Expand All @@ -276,21 +301,27 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste

return
}
defer localConn.Close() //nolint

done := make(chan struct{})
go func() {
defer close(done)
_, err1 := io.Copy(dmsgConn, localConn)
if err1 != nil {
dlog.WithError(err1).Warn("Error on io.Copy(dmsgConn, localConn)")
dlog.WithError(err1).Debug("io.Copy(dmsgConn, localConn) ended")
}
}()
_, err2 := io.Copy(localConn, dmsgConn)
if err2 != nil {
dlog.WithError(err2).Warn("Error on io.Copy(localConn, dmsgConn)")
dlog.WithError(err2).Debug("io.Copy(localConn, dmsgConn) ended")
}
// Close both to unblock the goroutine
dmsgConn.Close() //nolint
localConn.Close() //nolint
if err := dmsgConn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing dmsg conn")
}
if err := localConn.Close(); err != nil {
dlog.WithError(err).Debug("Error closing local conn")
}
<-done

connMutex.Lock()
delete(activeConns, dmsgConn)
Expand Down
Loading
Loading