From b6b094413596cd8319c3dc71c9d7f286b9745ddb Mon Sep 17 00:00:00 2001 From: M03ED <50927468+M03ED@users.noreply.github.com> Date: Fri, 14 Mar 2025 22:37:27 +0330 Subject: [PATCH 1/2] fix: keep alive bug --- controller/controller.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/controller/controller.go b/controller/controller.go index 897275a..0cb53aa 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -65,6 +65,8 @@ func (c *Controller) Disconnect() { c.mu.Lock() defer c.mu.Unlock() + c.aliveCancel() + if c.backend != nil { c.backend.Shutdown() } @@ -75,7 +77,6 @@ func (c *Controller) Disconnect() { c.sessionID = uuid.Nil - c.aliveCancel() } func (c *Controller) NewRequest() { From 070fc24521d2db15c89ea358105c0cc5c8da1933 Mon Sep 17 00:00:00 2001 From: M03ED <50927468+M03ED@users.noreply.github.com> Date: Sat, 15 Mar 2025 05:38:56 +0330 Subject: [PATCH 2/2] fix: prevent race condition's and remove unnecessary methods --- backend/xray/xray.go | 53 ++++++++++++++----------- controller/controller.go | 70 ++++++++++++++------------------ controller/rest/base.go | 18 ++++----- controller/rest/log.go | 2 +- controller/rest/middleware.go | 4 +- controller/rest/rest_test.go | 2 +- controller/rest/service.go | 75 ++++------------------------------- controller/rest/stats.go | 18 ++++----- controller/rest/user.go | 4 +- controller/rpc/base.go | 30 +++++++++++--- controller/rpc/log.go | 4 +- controller/rpc/middleware.go | 8 ++-- controller/rpc/rpc_test.go | 2 +- controller/rpc/service.go | 28 +++---------- controller/rpc/stats.go | 18 ++++----- controller/rpc/user.go | 11 ++--- main.go | 2 +- 17 files changed, 142 insertions(+), 207 deletions(-) diff --git a/backend/xray/xray.go b/backend/xray/xray.go index d3fcca0..47f853a 100644 --- a/backend/xray/xray.go +++ b/backend/xray/xray.go @@ -25,9 +25,8 @@ type Xray struct { core *Core handler *api.XrayHandler configPath string - ctx context.Context cancelFunc context.CancelFunc - mu sync.Mutex + mu sync.RWMutex } func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPath string) (*Xray, error) { @@ -48,7 +47,7 @@ func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPa xCtx, xCancel := context.WithCancel(context.Background()) - xray := &Xray{configPath: configAbsolutePath, ctx: xCtx, cancelFunc: xCancel} + xray := &Xray{configPath: configAbsolutePath, cancelFunc: xCancel} start := time.Now() @@ -94,7 +93,7 @@ func NewXray(ctx context.Context, port int, executablePath, assetsPath, configPa return nil, err } xray.setHandler(handler) - go xray.checkXrayHealth() + go xray.checkXrayHealth(xCtx) return xray, nil } @@ -106,8 +105,8 @@ func (x *Xray) setConfig(config *Config) { } func (x *Xray) getConfig() *Config { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.config } @@ -118,26 +117,26 @@ func (x *Xray) setCore(core *Core) { } func (x *Xray) getCore() *Core { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.core } func (x *Xray) GetCore() backend.Core { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.core } func (x *Xray) GetLogs() chan string { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.core.GetLogs() } func (x *Xray) GetVersion() string { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.core.GetVersion() } @@ -148,8 +147,8 @@ func (x *Xray) setHandler(handler *api.XrayHandler) { } func (x *Xray) getHandler() *api.XrayHandler { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() return x.handler } @@ -213,8 +212,8 @@ func (x *Xray) GetInboundStats(ctx context.Context, tag string, reset bool) (*co } func (x *Xray) GenerateConfigFile() error { - x.mu.Lock() - defer x.mu.Unlock() + x.mu.RLock() + defer x.mu.RUnlock() var prettyJSON bytes.Buffer @@ -271,21 +270,29 @@ Loop: return nil } -func (x *Xray) checkXrayHealth() { +func (x *Xray) checkXrayHealth(baseCtx context.Context) { for { select { - case <-x.ctx.Done(): + case <-baseCtx.Done(): return default: - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - if _, err := x.GetSysStats(ctx); err != nil { + ctx, cancel := context.WithTimeout(baseCtx, time.Second*2) + _, err := x.GetSysStats(ctx) + cancel() // Always call cancel to avoid context leak + + if err != nil { + if errors.Is(err, context.Canceled) { + // Context was canceled due to x.ctx cancellation + return // Exit gracefully + } + + // Handle other errors by attempting restart if err = x.Restart(); err != nil { nodeLogger.Log(nodeLogger.LogError, err.Error()) } else { nodeLogger.Log(nodeLogger.LogInfo, "xray restarted") } } - cancel() } time.Sleep(time.Second * 2) } diff --git a/controller/controller.go b/controller/controller.go index 0cb53aa..081932a 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -19,54 +19,55 @@ import ( const NodeVersion = "1.0.0" type Service interface { - StopService() + Disconnect() } type Controller struct { backend backend.Backend sessionID uuid.UUID apiPort int + clientIP string lastRequest time.Time stats *common.SystemStatsResponse - aliveCancel context.CancelFunc - statsCancel context.CancelFunc - mu sync.Mutex + cancelFunc context.CancelFunc + mu sync.RWMutex } -func NewController() *Controller { - c := &Controller{ - sessionID: uuid.Nil, - apiPort: tools.FindFreePort(), - } - c.startJobs() - return c +func (c *Controller) Init() { + c.mu.Lock() + defer c.mu.Unlock() + c.sessionID = uuid.Nil + c.apiPort = tools.FindFreePort() + _, c.cancelFunc = context.WithCancel(context.Background()) } func (c *Controller) GetSessionID() uuid.UUID { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.sessionID } -func (c *Controller) Connect(keepAlive uint64) { +func (c *Controller) Connect(ip string, keepAlive uint64) { c.mu.Lock() defer c.mu.Unlock() c.sessionID = uuid.New() c.lastRequest = time.Now() + c.clientIP = ip ctx, cancel := context.WithCancel(context.Background()) - c.aliveCancel = cancel + c.cancelFunc = cancel + go c.recordSystemStats(ctx) if keepAlive > 0 { go c.keepAliveTracker(ctx, time.Duration(keepAlive)*time.Second) } } func (c *Controller) Disconnect() { + c.cancelFunc() + c.mu.Lock() defer c.mu.Unlock() - c.aliveCancel() - if c.backend != nil { c.backend.Shutdown() } @@ -76,7 +77,13 @@ func (c *Controller) Disconnect() { c.apiPort = apiPort c.sessionID = uuid.Nil + c.clientIP = "" +} +func (c *Controller) GetIP() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.clientIP } func (c *Controller) NewRequest() { @@ -105,8 +112,8 @@ func (c *Controller) StartBackend(ctx context.Context, backendType common.Backen } func (c *Controller) GetBackend() backend.Backend { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.backend } @@ -116,9 +123,9 @@ func (c *Controller) keepAliveTracker(ctx context.Context, keepAlive time.Durati case <-ctx.Done(): break default: - c.mu.Lock() + c.mu.RLock() lastRequest := c.lastRequest - c.mu.Unlock() + c.mu.RUnlock() if time.Since(lastRequest) >= keepAlive { log.Println("disconnect automatically due to keep alive timeout") c.Disconnect() @@ -148,8 +155,8 @@ func (c *Controller) recordSystemStats(ctx context.Context) { } func (c *Controller) GetStats() *common.SystemStatsResponse { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.stats } @@ -174,20 +181,3 @@ func (c *Controller) BaseInfoResponse(includeID bool, extra string) *common.Base return response } - -func (c *Controller) startJobs() { - ctx, cancel := context.WithCancel(context.Background()) - c.mu.Lock() - defer c.mu.Unlock() - c.statsCancel = cancel - go c.recordSystemStats(ctx) -} - -func (c *Controller) StopJobs() { - c.mu.Lock() - c.statsCancel() - c.mu.Unlock() - - c.Disconnect() - -} diff --git a/controller/rest/base.go b/controller/rest/base.go index 9b49d74..a04d475 100644 --- a/controller/rest/base.go +++ b/controller/rest/base.go @@ -13,7 +13,7 @@ import ( ) func (s *Service) Base(w http.ResponseWriter, _ *http.Request) { - common.SendProtoResponse(w, s.controller.BaseInfoResponse(false, "")) + common.SendProtoResponse(w, s.BaseInfoResponse(false, "")) } func (s *Service) Start(w http.ResponseWriter, r *http.Request) { @@ -29,26 +29,26 @@ func (s *Service) Start(w http.ResponseWriter, r *http.Request) { return } - if s.controller.GetBackend() != nil { + if s.GetBackend() != nil { log.Println("New connection from ", ip, " core control access was taken away from previous client.") - s.disconnect() + s.Disconnect() } - s.connect(ip, keepAlive) + s.Connect(ip, keepAlive) - log.Println(ip, " connected, Session ID = ", s.controller.GetSessionID()) + log.Println(ip, " connected, Session ID = ", s.GetSessionID()) - if err = s.controller.StartBackend(ctx, backendType); err != nil { + if err = s.StartBackend(ctx, backendType); err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } - common.SendProtoResponse(w, s.controller.BaseInfoResponse(true, "")) + common.SendProtoResponse(w, s.BaseInfoResponse(true, "")) } func (s *Service) Stop(w http.ResponseWriter, _ *http.Request) { - log.Println(s.GetIP(), " disconnected, Session ID = ", s.controller.GetSessionID()) - s.disconnect() + log.Println(s.GetIP(), " disconnected, Session ID = ", s.GetSessionID()) + s.Disconnect() common.SendProtoResponse(w, &common.Empty{}) } diff --git a/controller/rest/log.go b/controller/rest/log.go index 8a97f17..02a7a12 100644 --- a/controller/rest/log.go +++ b/controller/rest/log.go @@ -16,7 +16,7 @@ func (s *Service) GetLogs(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - logChan := s.controller.GetBackend().GetLogs() + logChan := s.GetBackend().GetLogs() for { select { diff --git a/controller/rest/middleware.go b/controller/rest/middleware.go index cd0caf0..766dd3a 100644 --- a/controller/rest/middleware.go +++ b/controller/rest/middleware.go @@ -15,7 +15,7 @@ func (s *Service) checkSessionIDMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // check ip clientIP := s.GetIP() - clientID := s.controller.GetSessionID() + clientID := s.GetSessionID() if clientIP == "" || clientID == uuid.Nil { http.Error(w, "please connect first", http.StatusTooEarly) return @@ -61,7 +61,7 @@ func (s *Service) checkSessionIDMiddleware(next http.Handler) http.Handler { func (s *Service) checkBackendMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - back := s.controller.GetBackend() + back := s.GetBackend() if back == nil { http.Error(w, "backend not initialized", http.StatusInternalServerError) return diff --git a/controller/rest/rest_test.go b/controller/rest/rest_test.go index c1d9d29..fe515d2 100644 --- a/controller/rest/rest_test.go +++ b/controller/rest/rest_test.go @@ -79,7 +79,7 @@ func TestRESTConnection(t *testing.T) { if err != nil { t.Fatalf("Failed to start HTTP listener: %v", err) } - defer s.StopService() + defer s.Disconnect() creds, err := tools.LoadTLSCredentials(sslClientCertFile, sslClientKeyFile, sslCertFile, true) if err != nil { diff --git a/controller/rest/service.go b/controller/rest/service.go index 9837a1f..105a63e 100644 --- a/controller/rest/service.go +++ b/controller/rest/service.go @@ -4,20 +4,15 @@ import ( "context" "crypto/tls" "errors" - "log" - "net/http" - "sync" - "github.com/go-chi/chi/v5" - "github.com/m03ed/gozargah-node/common" "github.com/m03ed/gozargah-node/controller" + "log" + "net/http" ) func NewService() *Service { - s := &Service{ - controller: controller.NewController(), - clientIP: "", - } + s := &Service{} + s.Init() s.setRouter() return s } @@ -59,68 +54,12 @@ func (s *Service) setRouter() { }) }) - s.mu.Lock() - defer s.mu.Unlock() s.Router = router } type Service struct { - Router chi.Router - clientIP string - controller *controller.Controller - mu sync.Mutex -} - -func (s *Service) connect(ip string, keepAlive uint64) { - s.mu.Lock() - defer s.mu.Unlock() - s.clientIP = ip - s.controller.Connect(keepAlive) -} - -func (s *Service) disconnect() { - s.controller.Disconnect() - - s.mu.Lock() - defer s.mu.Unlock() - - s.clientIP = "" -} - -func (s *Service) StopService() { - s.mu.Lock() - defer s.mu.Unlock() - s.controller.StopJobs() -} - -func (s *Service) GetIP() string { - s.mu.Lock() - defer s.mu.Unlock() - return s.clientIP -} - -func (s *Service) response(includeID bool, extra string) *common.BaseInfoResponse { - response := &common.BaseInfoResponse{ - Started: false, - CoreVersion: "", - NodeVersion: controller.NodeVersion, - Extra: extra, - } - - s.mu.Lock() - defer s.mu.Unlock() - - back := s.controller.GetBackend() - if back != nil { - response.Started = back.Started() - response.CoreVersion = back.GetVersion() - } - - if includeID { - response.SessionId = s.controller.GetSessionID().String() - } - - return response + controller.Controller + Router chi.Router } func StartHttpListener(tlsConfig *tls.Config, addr string) (func(ctx context.Context) error, controller.Service, error) { @@ -141,5 +80,5 @@ func StartHttpListener(tlsConfig *tls.Config, addr string) (func(ctx context.Con }() // Return a shutdown function for HTTP server - return httpServer.Shutdown, controller.Service(s), nil + return httpServer.Shutdown, s, nil } diff --git a/controller/rest/stats.go b/controller/rest/stats.go index c4c878f..bb913aa 100644 --- a/controller/rest/stats.go +++ b/controller/rest/stats.go @@ -13,7 +13,7 @@ func (s *Service) GetOutboundsStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetOutboundsStats(r.Context(), request.GetReset_()) + stats, err := s.GetBackend().GetOutboundsStats(r.Context(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -34,7 +34,7 @@ func (s *Service) GetOutboundStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetOutboundStats(r.Context(), request.GetName(), request.GetReset_()) + stats, err := s.GetBackend().GetOutboundStats(r.Context(), request.GetName(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -50,7 +50,7 @@ func (s *Service) GetInboundsStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetInboundsStats(r.Context(), request.GetReset_()) + stats, err := s.GetBackend().GetInboundsStats(r.Context(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -71,7 +71,7 @@ func (s *Service) GetInboundStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetInboundStats(r.Context(), request.GetName(), request.GetReset_()) + stats, err := s.GetBackend().GetInboundStats(r.Context(), request.GetName(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -87,7 +87,7 @@ func (s *Service) GetUsersStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetUsersStats(r.Context(), request.GetReset_()) + stats, err := s.GetBackend().GetUsersStats(r.Context(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -108,7 +108,7 @@ func (s *Service) GetUserStats(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetUserStats(r.Context(), request.GetName(), request.GetReset_()) + stats, err := s.GetBackend().GetUserStats(r.Context(), request.GetName(), request.GetReset_()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -129,7 +129,7 @@ func (s *Service) GetUserOnlineStat(w http.ResponseWriter, r *http.Request) { return } - stats, err := s.controller.GetBackend().GetStatOnline(r.Context(), request.GetName()) + stats, err := s.GetBackend().GetStatOnline(r.Context(), request.GetName()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -139,7 +139,7 @@ func (s *Service) GetUserOnlineStat(w http.ResponseWriter, r *http.Request) { } func (s *Service) GetBackendStats(w http.ResponseWriter, r *http.Request) { - stats, err := s.controller.GetBackend().GetSysStats(r.Context()) + stats, err := s.GetBackend().GetSysStats(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -149,5 +149,5 @@ func (s *Service) GetBackendStats(w http.ResponseWriter, r *http.Request) { } func (s *Service) GetSystemStats(w http.ResponseWriter, _ *http.Request) { - common.SendProtoResponse(w, s.controller.GetStats()) + common.SendProtoResponse(w, s.GetStats()) } diff --git a/controller/rest/user.go b/controller/rest/user.go index 7c91c2f..5781b10 100644 --- a/controller/rest/user.go +++ b/controller/rest/user.go @@ -28,7 +28,7 @@ func (s *Service) SyncUser(w http.ResponseWriter, r *http.Request) { return } - if err = s.controller.GetBackend().SyncUser(r.Context(), user); err != nil { + if err = s.GetBackend().SyncUser(r.Context(), user); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -56,7 +56,7 @@ func (s *Service) SyncUsers(w http.ResponseWriter, r *http.Request) { return } - if err = s.controller.GetBackend().SyncUsers(r.Context(), users.GetUsers()); err != nil { + if err = s.GetBackend().SyncUsers(r.Context(), users.GetUsers()); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/controller/rpc/base.go b/controller/rpc/base.go index c2dfbe0..46388f7 100644 --- a/controller/rpc/base.go +++ b/controller/rpc/base.go @@ -3,9 +3,12 @@ package rpc import ( "context" "errors" + "net" + "github.com/m03ed/gozargah-node/backend" "github.com/m03ed/gozargah-node/backend/xray" "github.com/m03ed/gozargah-node/common" + "google.golang.org/grpc/peer" ) func (s *Service) Start(ctx context.Context, detail *common.Backend) (*common.BaseInfoResponse, error) { @@ -14,17 +17,34 @@ func (s *Service) Start(ctx context.Context, detail *common.Backend) (*common.Ba return nil, err } - if err = s.controller.StartBackend(ctx, detail.GetType()); err != nil { + if err = s.StartBackend(ctx, detail.GetType()); err != nil { return nil, err } - s.connect(detail.GetKeepAlive()) + clientIP := "" + if p, ok := peer.FromContext(ctx); ok { + // Extract IP address from peer address + if tcpAddr, ok := p.Addr.(*net.TCPAddr); ok { + clientIP = tcpAddr.IP.String() + } else { + // For other address types, extract just the IP without the port + addr := p.Addr.String() + if host, _, err := net.SplitHostPort(addr); err == nil { + clientIP = host + } else { + // If SplitHostPort fails, use the whole address + clientIP = addr + } + } + } + + s.Connect(clientIP, detail.GetKeepAlive()) - return s.controller.BaseInfoResponse(true, ""), nil + return s.BaseInfoResponse(true, ""), nil } func (s *Service) Stop(_ context.Context, _ *common.Empty) (*common.Empty, error) { - s.disconnect() + s.Disconnect() return nil, nil } @@ -45,5 +65,5 @@ func (s *Service) detectBackend(ctx context.Context, detail *common.Backend) (co } func (s *Service) GetBaseInfo(_ context.Context, _ *common.Empty) (*common.BaseInfoResponse, error) { - return s.controller.BaseInfoResponse(false, ""), nil + return s.BaseInfoResponse(false, ""), nil } diff --git a/controller/rpc/log.go b/controller/rpc/log.go index 6042455..3369ebd 100644 --- a/controller/rpc/log.go +++ b/controller/rpc/log.go @@ -8,7 +8,7 @@ import ( ) func (s *Service) GetLogs(_ *common.Empty, stream common.NodeService_GetLogsServer) error { - logChan := s.controller.GetBackend().GetLogs() + logChan := s.GetBackend().GetLogs() for { select { @@ -23,7 +23,7 @@ func (s *Service) GetLogs(_ *common.Empty, stream common.NodeService_GetLogsServ case <-stream.Context().Done(): // Client has disconnected or cancelled the request - return stream.Context().Err() + return nil } } } diff --git a/controller/rpc/middleware.go b/controller/rpc/middleware.go index c0c2bae..3f47212 100644 --- a/controller/rpc/middleware.go +++ b/controller/rpc/middleware.go @@ -23,7 +23,7 @@ func validateSessionID(ctx context.Context, s *Service) error { } // Check session ID - sessionID := s.controller.GetSessionID() + sessionID := s.GetSessionID() if sessionID == uuid.Nil { return status.Errorf(codes.Unauthenticated, "please connect first") } @@ -51,8 +51,8 @@ func validateSessionID(ctx context.Context, s *Service) error { if token != sessionID { return status.Errorf(codes.PermissionDenied, "session ID mismatch") } - - s.controller.NewRequest() + + s.NewRequest() return nil } @@ -90,7 +90,7 @@ func CheckSessionIDStreamMiddleware(s *Service) grpc.StreamServerInterceptor { } func checkBackendStatus(s *Service) error { - back := s.controller.GetBackend() + back := s.GetBackend() if back == nil { return status.Errorf(codes.Internal, "backend not initialized") } diff --git a/controller/rpc/rpc_test.go b/controller/rpc/rpc_test.go index 6623503..13c634c 100644 --- a/controller/rpc/rpc_test.go +++ b/controller/rpc/rpc_test.go @@ -66,7 +66,7 @@ func TestGRPCConnection(t *testing.T) { } shutdownFunc, s, err := StartGRPCListener(tlsConfig, addr) - defer s.StopService() + defer s.Disconnect() if err != nil { t.Fatal(err) } diff --git a/controller/rpc/service.go b/controller/rpc/service.go index 7d08e32..f2bb131 100644 --- a/controller/rpc/service.go +++ b/controller/rpc/service.go @@ -4,41 +4,25 @@ import ( "context" "crypto/tls" "fmt" + "github.com/m03ed/gozargah-node/common" + "github.com/m03ed/gozargah-node/controller" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "log" "net" - "sync" - - "github.com/m03ed/gozargah-node/common" - "github.com/m03ed/gozargah-node/controller" ) type Service struct { common.UnimplementedNodeServiceServer - controller *controller.Controller - mu sync.Mutex + controller.Controller } func NewService() *Service { - s := &Service{controller: controller.NewController()} + s := &Service{} + s.Init() return s } -func (s *Service) StopService() { - s.mu.Lock() - defer s.mu.Unlock() - s.controller.StopJobs() -} - -func (s *Service) connect(keepAlive uint64) { - s.controller.Connect(keepAlive) -} - -func (s *Service) disconnect() { - s.controller.Disconnect() -} - func StartGRPCListener(tlsConfig *tls.Config, addr string) (func(ctx context.Context) error, controller.Service, error) { s := NewService() @@ -84,5 +68,5 @@ func StartGRPCListener(tlsConfig *tls.Config, addr string) (func(ctx context.Con grpcServer.Stop() // Force stop if graceful stop times out return ctx.Err() } - }, controller.Service(s), nil + }, s, nil } diff --git a/controller/rpc/stats.go b/controller/rpc/stats.go index 40e3738..cc67c6f 100644 --- a/controller/rpc/stats.go +++ b/controller/rpc/stats.go @@ -8,49 +8,49 @@ import ( ) func (s *Service) GetOutboundsStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { - return s.controller.GetBackend().GetOutboundsStats(ctx, request.GetReset_()) + return s.GetBackend().GetOutboundsStats(ctx, request.GetReset_()) } func (s *Service) GetOutboundStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { if request.GetName() == "" { return nil, errors.New("name is required") } - return s.controller.GetBackend().GetOutboundStats(ctx, request.GetName(), request.GetReset_()) + return s.GetBackend().GetOutboundStats(ctx, request.GetName(), request.GetReset_()) } func (s *Service) GetInboundsStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { - return s.controller.GetBackend().GetInboundsStats(ctx, request.GetReset_()) + return s.GetBackend().GetInboundsStats(ctx, request.GetReset_()) } func (s *Service) GetInboundStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { if request.GetName() == "" { return nil, errors.New("name is required") } - return s.controller.GetBackend().GetInboundStats(ctx, request.GetName(), request.GetReset_()) + return s.GetBackend().GetInboundStats(ctx, request.GetName(), request.GetReset_()) } func (s *Service) GetUsersStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { - return s.controller.GetBackend().GetUsersStats(ctx, request.GetReset_()) + return s.GetBackend().GetUsersStats(ctx, request.GetReset_()) } func (s *Service) GetUserStats(ctx context.Context, request *common.StatRequest) (*common.StatResponse, error) { if request.GetName() == "" { return nil, errors.New("name is required") } - return s.controller.GetBackend().GetUserStats(ctx, request.GetName(), request.GetReset_()) + return s.GetBackend().GetUserStats(ctx, request.GetName(), request.GetReset_()) } func (s *Service) GetUserOnlineStats(ctx context.Context, request *common.StatRequest) (*common.OnlineStatResponse, error) { if request.GetName() == "" { return nil, errors.New("name is required") } - return s.controller.GetBackend().GetStatOnline(ctx, request.GetName()) + return s.GetBackend().GetStatOnline(ctx, request.GetName()) } func (s *Service) GetBackendStats(ctx context.Context, _ *common.Empty) (*common.BackendStatsResponse, error) { - return s.controller.GetBackend().GetSysStats(ctx) + return s.GetBackend().GetSysStats(ctx) } func (s *Service) GetSystemStats(_ context.Context, _ *common.Empty) (*common.SystemStatsResponse, error) { - return s.controller.GetStats(), nil + return s.GetStats(), nil } diff --git a/controller/rpc/user.go b/controller/rpc/user.go index 8d83bd2..2f7185e 100644 --- a/controller/rpc/user.go +++ b/controller/rpc/user.go @@ -3,8 +3,6 @@ package rpc import ( "context" "errors" - "io" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -15,25 +13,22 @@ import ( func (s *Service) SyncUser(stream grpc.ClientStreamingServer[common.User, common.Empty]) error { for { user, err := stream.Recv() - if err == io.EOF { - return stream.SendAndClose(&common.Empty{}) - } if err != nil { - return status.Errorf(codes.Internal, "failed to receive user: %v", err) + return stream.SendAndClose(&common.Empty{}) } if user.GetEmail() == "" { return errors.New("email is required") } - if err = s.controller.GetBackend().SyncUser(stream.Context(), user); err != nil { + if err = s.GetBackend().SyncUser(stream.Context(), user); err != nil { return status.Errorf(codes.Internal, "failed to update user: %v", err) } } } func (s *Service) SyncUsers(ctx context.Context, users *common.Users) (*common.Empty, error) { - if err := s.controller.GetBackend().SyncUsers(ctx, users.GetUsers()); err != nil { + if err := s.GetBackend().SyncUsers(ctx, users.GetUsers()); err != nil { return nil, err } diff --git a/main.go b/main.go index f258fb1..168af63 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ func main() { shutdownFunc, service, err = rpc.StartGRPCListener(tlsConfig, addr) } - defer service.StopService() + defer service.Disconnect() stopChan := make(chan os.Signal, 1) signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM)