diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index 332e935..5965e9a 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -10,6 +10,7 @@ import ( "net" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -141,13 +142,22 @@ func (cs *centralServer) handleSOCKS5Passthrough(clientConn net.Conn, firstByte } // handleFrameConn reads framed packets from a tunnel connection. +// When this function returns (source TCP died), it cleans up all +// connStates that had this as their only source. func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAddr string) { log.Printf("[central] frame connection from %s", remoteAddr) + // Track which ConnIDs this source served + servedIDs := make(map[uint32]bool) + + defer func() { + // Source TCP died — clean up connStates that only had this source + cs.cleanupSource(conn, servedIDs, remoteAddr) + }() + // Read remaining header bytes (we already read 1) var hdrRest [tunnel.HeaderSize - 1]byte if _, err := io.ReadFull(conn, hdrRest[:]); err != nil { - log.Printf("[central] %s: frame header read error: %v", remoteAddr, err) return } @@ -157,21 +167,75 @@ func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAd firstFrame := cs.parseHeader(fullHdr, conn, remoteAddr) if firstFrame != nil { + servedIDs[firstFrame.ConnID] = true cs.dispatchFrame(firstFrame, conn) } for { frame, err := tunnel.ReadFrame(conn) if err != nil { - if err != io.EOF { + if err != io.EOF && !isClosedConnErr(err) { log.Printf("[central] %s: read error: %v", remoteAddr, err) } return } + servedIDs[frame.ConnID] = true cs.dispatchFrame(frame, conn) } } +// cleanupSource removes a dead source connection from all connStates. +// If a connState has no remaining sources, it is fully cleaned up. +func (cs *centralServer) cleanupSource(deadSource net.Conn, servedIDs map[uint32]bool, remoteAddr string) { + cs.mu.Lock() + defer cs.mu.Unlock() + + cleaned := 0 + for connID := range servedIDs { + state, ok := cs.conns[connID] + if !ok { + continue + } + + state.mu.Lock() + // Remove dead source from sources list + for i, src := range state.sources { + if src == deadSource { + state.sources = append(state.sources[:i], state.sources[i+1:]...) + break + } + } + + // If no sources left, fully clean up this connState + if len(state.sources) == 0 { + state.mu.Unlock() + if state.cancel != nil { + state.cancel() + } + if state.target != nil { + state.target.Close() + } + delete(cs.conns, connID) + cleaned++ + } else { + state.mu.Unlock() + } + } + + if cleaned > 0 { + log.Printf("[central] %s: source disconnected, cleaned %d orphaned connections", remoteAddr, cleaned) + } +} + +func isClosedConnErr(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "use of closed network connection") || + strings.Contains(s, "connection reset by peer") +} + func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, remoteAddr string) *tunnel.Frame { length := binary.BigEndian.Uint16(hdr[9:11]) if length > tunnel.MaxPayloadSize { @@ -380,7 +444,7 @@ func (cs *centralServer) relayUpstreamToTunnel(ctx context.Context, connID uint3 cs.sendFrame(connID, frame) } if err != nil { - if err != io.EOF { + if err != io.EOF && !isClosedConnErr(err) { log.Printf("[central] conn=%d: upstream read error: %v", connID, err) } return diff --git a/internal/tunnel/pool.go b/internal/tunnel/pool.go index a007597..34b831c 100644 --- a/internal/tunnel/pool.go +++ b/internal/tunnel/pool.go @@ -1,52 +1,98 @@ package tunnel import ( + "fmt" "io" "log" "net" + "strings" "sync" + "sync/atomic" "time" "github.com/ParsaKSH/SlipStream-Plus/internal/engine" ) -// TunnelPool manages per-user-connection tunnel connections. -// Instead of maintaining persistent connections (which degrade over QUIC), -// each user connection gets its own fresh TCP connections to instances. -// This mirrors the proven connection-level approach but with multiplexing. +// writeTimeout prevents writes from blocking forever on stalled connections. +const writeTimeout = 10 * time.Second + +// readTimeout is applied per ReadFrame call. If no frame arrives within this +// duration, we check if the tunnel is stale. +const readTimeout = 30 * time.Second + +// staleThreshold: if we've sent data recently but haven't received anything +// in this long, the connection is considered half-dead. +const staleThreshold = 20 * time.Second + +// TunnelConn wraps a persistent TCP connection to a single instance. +type TunnelConn struct { + inst *engine.Instance + mu sync.Mutex + conn net.Conn + writeMu sync.Mutex + closed bool + + lastRead atomic.Int64 // unix millis of last successful read + lastWrite atomic.Int64 // unix millis of last successful write +} + +// TunnelPool manages ONE persistent connection per healthy instance. type TunnelPool struct { mgr *engine.Manager + mu sync.RWMutex + tunnels map[int]*TunnelConn handlers sync.Map // ConnID (uint32) → chan *Frame stopCh chan struct{} + wg sync.WaitGroup } -// NewTunnelPool creates a new tunnel pool. func NewTunnelPool(mgr *engine.Manager) *TunnelPool { return &TunnelPool{ - mgr: mgr, - stopCh: make(chan struct{}), + mgr: mgr, + tunnels: make(map[int]*TunnelConn), + stopCh: make(chan struct{}), } } -// Start initializes the pool. func (p *TunnelPool) Start() { - log.Printf("[tunnel-pool] started (per-connection mode)") + p.refreshConnections() + + p.wg.Add(1) + go func() { + defer p.wg.Done() + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + p.refreshConnections() + } + } + }() + + log.Printf("[tunnel-pool] started (stale_threshold=%s)", staleThreshold) } -// Stop signals all operations to cease. func (p *TunnelPool) Stop() { close(p.stopCh) + p.mu.Lock() + for _, tc := range p.tunnels { + tc.close() + } + p.tunnels = make(map[int]*TunnelConn) + p.mu.Unlock() + p.wg.Wait() log.Printf("[tunnel-pool] stopped") } -// RegisterConn creates a channel for receiving frames for a given ConnID. func (p *TunnelPool) RegisterConn(connID uint32) chan *Frame { ch := make(chan *Frame, 256) p.handlers.Store(connID, ch) return ch } -// UnregisterConn removes the handler for a ConnID. func (p *TunnelPool) UnregisterConn(connID uint32) { if v, ok := p.handlers.LoadAndDelete(connID); ok { ch := v.(chan *Frame) @@ -54,14 +100,75 @@ func (p *TunnelPool) UnregisterConn(connID uint32) { } } -// DialInstance creates a fresh TCP connection to an instance and starts -// a read loop that dispatches incoming frames to registered handlers. -// Returns the connection and a cleanup function. -// The caller is responsible for calling cleanup when done. -func (p *TunnelPool) DialInstance(inst *engine.Instance) (net.Conn, func(), error) { +func (p *TunnelPool) SendFrame(instID int, f *Frame) error { + p.mu.RLock() + tc, ok := p.tunnels[instID] + p.mu.RUnlock() + + if !ok { + return fmt.Errorf("no tunnel for instance %d", instID) + } + + return tc.writeFrame(f) +} + +// refreshConnections reconnects dead/stale tunnels and adds new healthy instances. +func (p *TunnelPool) refreshConnections() { + healthy := p.mgr.HealthyInstances() + nowMs := time.Now().UnixMilli() + + p.mu.Lock() + defer p.mu.Unlock() + + activeIDs := make(map[int]bool) + for _, inst := range healthy { + activeIDs[inst.ID()] = true + } + + for id, tc := range p.tunnels { + shouldRemove := false + + if !activeIDs[id] || tc.closed { + shouldRemove = true + } else { + // Detect half-dead connections: + // If we wrote recently but haven't read in staleThreshold, connection is dead. + lastW := tc.lastWrite.Load() + lastR := tc.lastRead.Load() + if lastW > 0 && (nowMs-lastR) > staleThreshold.Milliseconds() { + log.Printf("[tunnel-pool] instance %d: stale connection detected (last_read=%dms ago, last_write=%dms ago), recycling", + id, nowMs-lastR, nowMs-lastW) + shouldRemove = true + } + } + + if shouldRemove { + tc.close() + delete(p.tunnels, id) + } + } + + for _, inst := range healthy { + if inst.Config.Mode == "ssh" { + continue + } + if _, exists := p.tunnels[inst.ID()]; exists { + continue + } + tc, err := p.connectInstance(inst) + if err != nil { + continue + } + p.tunnels[inst.ID()] = tc + log.Printf("[tunnel-pool] connected to instance %d (%s:%d)", + inst.ID(), inst.Config.Domain, inst.Config.Port) + } +} + +func (p *TunnelPool) connectInstance(inst *engine.Instance) (*TunnelConn, error) { conn, err := inst.Dial() if err != nil { - return nil, nil, err + return nil, err } if tc, ok := conn.(*net.TCPConn); ok { @@ -70,56 +177,103 @@ func (p *TunnelPool) DialInstance(inst *engine.Instance) (net.Conn, func(), erro tc.SetNoDelay(true) } - closed := make(chan struct{}) - cleanup := func() { + now := time.Now().UnixMilli() + tunnel := &TunnelConn{ + inst: inst, + conn: conn, + } + tunnel.lastRead.Store(now) + tunnel.lastWrite.Store(0) // no writes yet + + p.wg.Add(1) + go func() { + defer p.wg.Done() + p.readLoop(tunnel) + }() + + return tunnel, nil +} + +func (p *TunnelPool) readLoop(tc *TunnelConn) { + for { select { - case <-closed: - return // already closed + case <-p.stopCh: + return default: - close(closed) - conn.Close() } - } - // Read loop: dispatch incoming frames to handlers - go func() { - defer conn.Close() - for { - select { - case <-p.stopCh: - return - case <-closed: - return - default: + // Set read deadline so we don't block forever on dead connections. + // If timeout fires, we loop back and try again — refreshConnections + // will detect staleness and force close if needed. + tc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + frame, err := ReadFrame(tc.conn) + if err != nil { + if isTimeoutErr(err) { + // Read timed out — not necessarily dead, just no data. + // refreshConnections will check staleness. + continue } - - frame, err := ReadFrame(conn) - if err != nil { - if err != io.EOF { - log.Printf("[tunnel-pool] read from instance %d: %v", inst.ID(), err) - } - return + if err != io.EOF && !isClosedErr(err) { + log.Printf("[tunnel-pool] instance %d read error: %v", tc.inst.ID(), err) } + tc.close() + return + } - if v, ok := p.handlers.Load(frame.ConnID); ok { - ch := v.(chan *Frame) - select { - case ch <- frame: - default: - // Handler buffer full, drop - } + tc.lastRead.Store(time.Now().UnixMilli()) + + if v, ok := p.handlers.Load(frame.ConnID); ok { + ch := v.(chan *Frame) + select { + case ch <- frame: + default: + // Buffer full — drop frame silently } } - }() - - return conn, cleanup, nil + } } -// SendFrame writes a frame directly to a connection. -// This is a convenience wrapper used by PacketSplitter. -func SendFrameTo(conn net.Conn, f *Frame) error { - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - err := WriteFrame(conn, f) - conn.SetWriteDeadline(time.Time{}) +func (tc *TunnelConn) writeFrame(f *Frame) error { + tc.writeMu.Lock() + defer tc.writeMu.Unlock() + + if tc.closed { + return fmt.Errorf("tunnel closed") + } + + tc.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + err := WriteFrame(tc.conn, f) + tc.conn.SetWriteDeadline(time.Time{}) + + if err == nil { + tc.lastWrite.Store(time.Now().UnixMilli()) + } return err } + +func (tc *TunnelConn) close() { + tc.mu.Lock() + defer tc.mu.Unlock() + if !tc.closed { + tc.closed = true + tc.conn.Close() + } +} + +func isClosedErr(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "connection reset by peer") +} + +func isTimeoutErr(err error) bool { + if err == nil { + return false + } + if ne, ok := err.(net.Error); ok { + return ne.Timeout() + } + return false +} diff --git a/internal/tunnel/splitter.go b/internal/tunnel/splitter.go index 7000f48..456343a 100644 --- a/internal/tunnel/splitter.go +++ b/internal/tunnel/splitter.go @@ -3,8 +3,6 @@ package tunnel import ( "context" "io" - "log" - "net" "sync" "sync/atomic" "time" @@ -15,24 +13,14 @@ import ( // PacketSplitter distributes data from a client connection across multiple // instances at the packet/chunk level, and reassembles reverse-direction // frames back to the client. -// -// Each PacketSplitter creates its own fresh TCP connections to instances -// (instead of using persistent pool connections). This gives each user -// connection fresh QUIC streams, avoiding the degradation that happens -// with long-lived multiplexed streams through DNS tunnels. type PacketSplitter struct { connID uint32 pool *TunnelPool instances []*engine.Instance chunkSize int - incoming chan *Frame // frames coming back from instances for this ConnID + incoming chan *Frame - txSeq atomic.Uint32 // next send sequence number - - // Per-instance connections: fresh TCP for each user connection - connMu sync.Mutex - instConns map[int]net.Conn // instance ID → TCP connection - cleanups []func() // cleanup functions for all connections + txSeq atomic.Uint32 // Weighted round-robin state mu sync.Mutex @@ -41,8 +29,6 @@ type PacketSplitter struct { counter int } -// NewPacketSplitter creates a splitter for one client connection. -// It dials fresh TCP connections to each instance. func NewPacketSplitter(connID uint32, pool *TunnelPool, instances []*engine.Instance, chunkSize int) *PacketSplitter { ps := &PacketSplitter{ connID: connID, @@ -50,48 +36,23 @@ func NewPacketSplitter(connID uint32, pool *TunnelPool, instances []*engine.Inst instances: instances, chunkSize: chunkSize, incoming: pool.RegisterConn(connID), - instConns: make(map[int]net.Conn), } ps.recalcWeights() - - // Dial fresh connections to all instances - for _, inst := range instances { - conn, cleanup, err := pool.DialInstance(inst) - if err != nil { - log.Printf("[splitter] conn=%d: dial instance %d failed: %v", - connID, inst.ID(), err) - continue - } - ps.instConns[inst.ID()] = conn - ps.cleanups = append(ps.cleanups, cleanup) - } - return ps } -// Close sends FIN to all instances, closes all connections, and unregisters. func (ps *PacketSplitter) Close() { fin := &Frame{ ConnID: ps.connID, SeqNum: ps.txSeq.Add(1) - 1, Flags: FlagFIN, } - - ps.connMu.Lock() - for _, conn := range ps.instConns { - SendFrameTo(conn, fin) + for _, inst := range ps.instances { + ps.pool.SendFrame(inst.ID(), fin) } - ps.connMu.Unlock() - - // Unregister handler first, then close connections ps.pool.UnregisterConn(ps.connID) - - for _, cleanup := range ps.cleanups { - cleanup() - } } -// SendSYN sends a SYN frame (with target address) through all instances. func (ps *PacketSplitter) SendSYN(atyp byte, addr []byte, port []byte) error { payload := EncodeSYNPayload(atyp, addr, port) frame := &Frame{ @@ -101,40 +62,19 @@ func (ps *PacketSplitter) SendSYN(atyp byte, addr []byte, port []byte) error { Payload: payload, } - ps.connMu.Lock() - defer ps.connMu.Unlock() - sent := 0 - for id, conn := range ps.instConns { - if err := SendFrameTo(conn, frame); err != nil { - log.Printf("[splitter] conn=%d: SYN to instance %d failed: %v", - ps.connID, id, err) + for _, inst := range ps.instances { + if err := ps.pool.SendFrame(inst.ID(), frame); err != nil { continue } sent++ } - if sent == 0 { return io.ErrClosedPipe } return nil } -// sendFrame sends a frame through a specific instance's connection. -func (ps *PacketSplitter) sendFrame(instID int, f *Frame) error { - ps.connMu.Lock() - conn, ok := ps.instConns[instID] - ps.connMu.Unlock() - - if !ok { - return io.ErrClosedPipe - } - - return SendFrameTo(conn, f) -} - -// RelayClientToUpstream reads from the client, splits into chunks, and sends -// through instances. Returns total bytes transferred. func (ps *PacketSplitter) RelayClientToUpstream(ctx context.Context, client io.Reader) int64 { buf := make([]byte, ps.chunkSize) var totalBytes int64 @@ -152,7 +92,6 @@ func (ps *PacketSplitter) RelayClientToUpstream(ctx context.Context, client io.R inst := ps.pickInstance() if inst == nil { - log.Printf("[splitter] conn=%d: no healthy instance available", ps.connID) return totalBytes } @@ -164,28 +103,20 @@ func (ps *PacketSplitter) RelayClientToUpstream(ctx context.Context, client io.R } copy(frame.Payload, buf[:n]) - if sendErr := ps.sendFrame(inst.ID(), frame); sendErr != nil { - log.Printf("[splitter] conn=%d: send to instance %d failed: %v", - ps.connID, inst.ID(), sendErr) - // Try another instance + if sendErr := ps.pool.SendFrame(inst.ID(), frame); sendErr != nil { inst2 := ps.pickInstanceExcluding(inst.ID()) if inst2 != nil { - ps.sendFrame(inst2.ID(), frame) + ps.pool.SendFrame(inst2.ID(), frame) } } } if err != nil { - if err != io.EOF { - log.Printf("[splitter] conn=%d: client read error: %v", ps.connID, err) - } return totalBytes } } } -// RelayUpstreamToClient reads frames from the incoming channel (reverse direction), -// reorders them, and writes to the client. Returns total bytes transferred. func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.Writer) int64 { reorderer := NewReorderer() var totalBytes int64 @@ -200,7 +131,6 @@ func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.W } if frame.IsFIN() || frame.IsRST() { - // Drain remaining ordered frames for { data := reorderer.Next() if data == nil { @@ -212,19 +142,15 @@ func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.W return totalBytes } - // Skip control frames (ACK, SYN) — only process data if frame.IsACK() || frame.IsSYN() { continue } - - // Skip empty payloads if len(frame.Payload) == 0 { continue } reorderer.Insert(frame.SeqNum, frame.Payload) - // Write all available in-order data for { data := reorderer.Next() if data == nil { @@ -240,7 +166,6 @@ func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.W } } -// recalcWeights updates the per-instance weights based on latency. func (ps *PacketSplitter) recalcWeights() { ps.mu.Lock() defer ps.mu.Unlock() @@ -267,7 +192,6 @@ func (ps *PacketSplitter) recalcWeights() { } } -// pickInstance selects the next instance using weighted round-robin. func (ps *PacketSplitter) pickInstance() *engine.Instance { ps.mu.Lock() defer ps.mu.Unlock() @@ -284,40 +208,25 @@ func (ps *PacketSplitter) pickInstance() *engine.Instance { ps.counter++ inst := ps.instances[ps.current] - // Check both health AND that we have a connection - ps.connMu.Lock() - _, hasConn := ps.instConns[inst.ID()] - ps.connMu.Unlock() - - if inst.IsHealthy() && hasConn { + if inst.IsHealthy() { return inst } ps.counter = 0 ps.current = (ps.current + 1) % len(ps.instances) } - // Fallback for _, inst := range ps.instances { - ps.connMu.Lock() - _, hasConn := ps.instConns[inst.ID()] - ps.connMu.Unlock() - if inst.IsHealthy() && hasConn { + if inst.IsHealthy() { return inst } } return nil } -// pickInstanceExcluding picks any healthy instance except the excluded one. func (ps *PacketSplitter) pickInstanceExcluding(excludeID int) *engine.Instance { for _, inst := range ps.instances { if inst.ID() != excludeID && inst.IsHealthy() { - ps.connMu.Lock() - _, hasConn := ps.instConns[inst.ID()] - ps.connMu.Unlock() - if hasConn { - return inst - } + return inst } } return nil @@ -330,7 +239,6 @@ type Reorderer struct { timeout time.Duration } -// NewReorderer creates a new frame reorderer starting from sequence 0. func NewReorderer() *Reorderer { return &Reorderer{ nextSeq: 0, @@ -339,7 +247,6 @@ func NewReorderer() *Reorderer { } } -// NewReordererAt creates a reorderer starting from a specific sequence number. func NewReordererAt(startSeq uint32) *Reorderer { return &Reorderer{ nextSeq: startSeq, @@ -348,7 +255,6 @@ func NewReordererAt(startSeq uint32) *Reorderer { } } -// Insert adds a frame to the reorder buffer. func (r *Reorderer) Insert(seq uint32, data []byte) { if seq < r.nextSeq { return @@ -356,7 +262,6 @@ func (r *Reorderer) Insert(seq uint32, data []byte) { r.buffer[seq] = data } -// Next returns the next in-order payload, or nil if waiting for a gap. func (r *Reorderer) Next() []byte { data, ok := r.buffer[r.nextSeq] if !ok { @@ -367,12 +272,10 @@ func (r *Reorderer) Next() []byte { return data } -// Pending returns the number of buffered out-of-order frames. func (r *Reorderer) Pending() int { return len(r.buffer) } -// SkipGap advances past a missing sequence number. func (r *Reorderer) SkipGap() { r.nextSeq++ }