Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions backend/xray/xray.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ type Xray struct {
core *Core
handler *api.XrayHandler
configPath string
ctx context.Context
cancelFunc context.CancelFunc
mu sync.Mutex
mu sync.RWMutex
}

func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPath string) (*Xray, error) {
Expand All @@ -48,7 +47,7 @@ func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPa

xCtx, xCancel := context.WithCancel(context.Background())

xray := &Xray{configPath: configAbsolutePath, ctx: xCtx, cancelFunc: xCancel}
xray := &Xray{configPath: configAbsolutePath, cancelFunc: xCancel}

start := time.Now()

Expand Down Expand Up @@ -94,7 +93,7 @@ func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPa
return nil, err
}
xray.setHandler(handler)
go xray.checkXrayHealth()
go xray.checkXrayHealth(xCtx)

return xray, nil
}
Expand All @@ -106,8 +105,8 @@ func (x *Xray) setConfig(config *Config) {
}

func (x *Xray) getConfig() *Config {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.config
}

Expand All @@ -118,26 +117,26 @@ func (x *Xray) setCore(core *Core) {
}

func (x *Xray) getCore() *Core {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.core
}

func (x *Xray) GetCore() backend.Core {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.core
}

func (x *Xray) GetLogs() chan string {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.core.GetLogs()
}

func (x *Xray) GetVersion() string {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.core.GetVersion()
}

Expand All @@ -148,8 +147,8 @@ func (x *Xray) setHandler(handler *api.XrayHandler) {
}

func (x *Xray) getHandler() *api.XrayHandler {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()
return x.handler
}

Expand Down Expand Up @@ -213,8 +212,8 @@ func (x *Xray) GetInboundStats(ctx context.Context, tag string, reset bool) (*co
}

func (x *Xray) GenerateConfigFile() error {
x.mu.Lock()
defer x.mu.Unlock()
x.mu.RLock()
defer x.mu.RUnlock()

var prettyJSON bytes.Buffer

Expand Down Expand Up @@ -271,21 +270,29 @@ Loop:
return nil
}

func (x *Xray) checkXrayHealth() {
func (x *Xray) checkXrayHealth(baseCtx context.Context) {
for {
select {
case <-x.ctx.Done():
case <-baseCtx.Done():
return
default:
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
if _, err := x.GetSysStats(ctx); err != nil {
ctx, cancel := context.WithTimeout(baseCtx, time.Second*2)
_, err := x.GetSysStats(ctx)
cancel() // Always call cancel to avoid context leak

if err != nil {
if errors.Is(err, context.Canceled) {
// Context was canceled due to x.ctx cancellation
return // Exit gracefully
}

// Handle other errors by attempting restart
if err = x.Restart(); err != nil {
nodeLogger.Log(nodeLogger.LogError, err.Error())
} else {
nodeLogger.Log(nodeLogger.LogInfo, "xray restarted")
}
}
cancel()
}
time.Sleep(time.Second * 2)
}
Expand Down
69 changes: 30 additions & 39 deletions controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,52 @@ import (
const NodeVersion = "1.0.0"

type Service interface {
StopService()
Disconnect()
}

type Controller struct {
backend backend.Backend
sessionID uuid.UUID
apiPort int
clientIP string
lastRequest time.Time
stats *common.SystemStatsResponse
aliveCancel context.CancelFunc
statsCancel context.CancelFunc
mu sync.Mutex
cancelFunc context.CancelFunc
mu sync.RWMutex
}

func NewController() *Controller {
c := &Controller{
sessionID: uuid.Nil,
apiPort: tools.FindFreePort(),
}
c.startJobs()
return c
func (c *Controller) Init() {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionID = uuid.Nil
c.apiPort = tools.FindFreePort()
_, c.cancelFunc = context.WithCancel(context.Background())
}

func (c *Controller) GetSessionID() uuid.UUID {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.sessionID
}

func (c *Controller) Connect(keepAlive uint64) {
func (c *Controller) Connect(ip string, keepAlive uint64) {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionID = uuid.New()
c.lastRequest = time.Now()
c.clientIP = ip

ctx, cancel := context.WithCancel(context.Background())
c.aliveCancel = cancel
c.cancelFunc = cancel
go c.recordSystemStats(ctx)
if keepAlive > 0 {
go c.keepAliveTracker(ctx, time.Duration(keepAlive)*time.Second)
}
}

func (c *Controller) Disconnect() {
c.cancelFunc()

c.mu.Lock()
defer c.mu.Unlock()

Expand All @@ -74,8 +77,13 @@ func (c *Controller) Disconnect() {
c.apiPort = apiPort

c.sessionID = uuid.Nil
c.clientIP = ""
}

c.aliveCancel()
func (c *Controller) GetIP() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.clientIP
}

func (c *Controller) NewRequest() {
Expand Down Expand Up @@ -104,8 +112,8 @@ func (c *Controller) StartBackend(ctx context.Context, backendType common.Backen
}

func (c *Controller) GetBackend() backend.Backend {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.backend
}

Expand All @@ -115,9 +123,9 @@ func (c *Controller) keepAliveTracker(ctx context.Context, keepAlive time.Durati
case <-ctx.Done():
break
default:
c.mu.Lock()
c.mu.RLock()
lastRequest := c.lastRequest
c.mu.Unlock()
c.mu.RUnlock()
if time.Since(lastRequest) >= keepAlive {
log.Println("disconnect automatically due to keep alive timeout")
c.Disconnect()
Expand Down Expand Up @@ -147,8 +155,8 @@ func (c *Controller) recordSystemStats(ctx context.Context) {
}

func (c *Controller) GetStats() *common.SystemStatsResponse {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.stats
}

Expand All @@ -173,20 +181,3 @@ func (c *Controller) BaseInfoResponse(includeID bool, extra string) *common.Base

return response
}

func (c *Controller) startJobs() {
ctx, cancel := context.WithCancel(context.Background())
c.mu.Lock()
defer c.mu.Unlock()
c.statsCancel = cancel
go c.recordSystemStats(ctx)
}

func (c *Controller) StopJobs() {
c.mu.Lock()
c.statsCancel()
c.mu.Unlock()

c.Disconnect()

}
18 changes: 9 additions & 9 deletions controller/rest/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func (s *Service) Base(w http.ResponseWriter, _ *http.Request) {
common.SendProtoResponse(w, s.controller.BaseInfoResponse(false, ""))
common.SendProtoResponse(w, s.BaseInfoResponse(false, ""))
}

func (s *Service) Start(w http.ResponseWriter, r *http.Request) {
Expand All @@ -29,26 +29,26 @@ func (s *Service) Start(w http.ResponseWriter, r *http.Request) {
return
}

if s.controller.GetBackend() != nil {
if s.GetBackend() != nil {
log.Println("New connection from ", ip, " core control access was taken away from previous client.")
s.disconnect()
s.Disconnect()
}

s.connect(ip, keepAlive)
s.Connect(ip, keepAlive)

log.Println(ip, " connected, Session ID = ", s.controller.GetSessionID())
log.Println(ip, " connected, Session ID = ", s.GetSessionID())

if err = s.controller.StartBackend(ctx, backendType); err != nil {
if err = s.StartBackend(ctx, backendType); err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}

common.SendProtoResponse(w, s.controller.BaseInfoResponse(true, ""))
common.SendProtoResponse(w, s.BaseInfoResponse(true, ""))
}

func (s *Service) Stop(w http.ResponseWriter, _ *http.Request) {
log.Println(s.GetIP(), " disconnected, Session ID = ", s.controller.GetSessionID())
s.disconnect()
log.Println(s.GetIP(), " disconnected, Session ID = ", s.GetSessionID())
s.Disconnect()

common.SendProtoResponse(w, &common.Empty{})
}
Expand Down
2 changes: 1 addition & 1 deletion controller/rest/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (s *Service) GetLogs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

logChan := s.controller.GetBackend().GetLogs()
logChan := s.GetBackend().GetLogs()

for {
select {
Expand Down
4 changes: 2 additions & 2 deletions controller/rest/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (s *Service) checkSessionIDMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check ip
clientIP := s.GetIP()
clientID := s.controller.GetSessionID()
clientID := s.GetSessionID()
if clientIP == "" || clientID == uuid.Nil {
http.Error(w, "please connect first", http.StatusTooEarly)
return
Expand Down Expand Up @@ -61,7 +61,7 @@ func (s *Service) checkSessionIDMiddleware(next http.Handler) http.Handler {

func (s *Service) checkBackendMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
back := s.controller.GetBackend()
back := s.GetBackend()
if back == nil {
http.Error(w, "backend not initialized", http.StatusInternalServerError)
return
Expand Down
2 changes: 1 addition & 1 deletion controller/rest/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestRESTConnection(t *testing.T) {
if err != nil {
t.Fatalf("Failed to start HTTP listener: %v", err)
}
defer s.StopService()
defer s.Disconnect()

creds, err := tools.LoadTLSCredentials(sslClientCertFile, sslClientKeyFile, sslCertFile, true)
if err != nil {
Expand Down
Loading