diff --git a/go.mod b/go.mod index 9bf8c14..6bd499b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.5 require ( github.com/afritzler/protoequal v0.1.10 + github.com/alitto/pond v1.9.2 github.com/gofrs/flock v0.13.0 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.4 @@ -46,6 +47,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect + go.uber.org/mock v0.6.0 // indirect go.uber.org/multierr v1.10.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect diff --git a/go.sum b/go.sum index d0ea10c..497068e 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1 github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/afritzler/protoequal v0.1.10 h1:HRWukWQ6Q0msWv0BwArWJKcjZ1Hz2qFoYlKb1VnzkTg= github.com/afritzler/protoequal v0.1.10/go.mod h1:65ALCt5ghpaRzoWohyRnx88X7o5y6cQwJmOb9yzdheg= +github.com/alitto/pond v1.9.2 h1:9Qb75z/scEZVCoSU+osVmQ0I0JOeLfdTDafrbcJ8CLs= +github.com/alitto/pond v1.9.2/go.mod h1:xQn3P/sHTYcU/1BR3i86IGIrilcrGC2LiS+E2+CJWsI= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -208,6 +210,8 @@ go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= 
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= diff --git a/internal/app/app.go b/internal/app/app.go index f209702..5d8bf83 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -265,7 +265,7 @@ func NewApp( pb.RegisterGetQueryInfoServer(s, &grpc.GetQueryInfoServer{Logger: baseApp.L(), MaxMessageSize: int(config.MaxOuterMessageSize), RQStorage: backgroundStorage.RQStorage}) pb.RegisterAgentControlServer(s, &grpc.AgentControlServer{Logger: baseApp.L(), RQStorage: backgroundStorage.RQStorage}) - getMasterInfo := grpc.NewGetMasterInfoServer(config.ClusterID, baseApp.L(), statActivityLister, int(config.MaxOuterMessageSize), backgroundStorage) + getMasterInfo := grpc.NewGetMasterInfoServer(config.ClusterID, baseApp.L(), int(config.MaxOuterMessageSize), backgroundStorage) actionInfo := &grpc.ActionsServer{ClusterID: config.ClusterID, Logger: baseApp.L(), Timeout: 5 * time.Minute, BackgroundStorage: backgroundStorage} pbm.RegisterGetGPInfoServer(s, getMasterInfo) @@ -382,12 +382,12 @@ func Run(ctx context.Context, configFile string) error { metrics.YagpccMetrics.ExecutingQueryLatencies.AssignQueryGetter(rqStorage.GetQueriesStartTime) aggStorage := storage.NewConfiguredAggregatedStorage(logger, cfg) sessionsStorage := gp.NewSessionsStorage(rqStorage) - backgroundStorage := master.NewBackgroundStorage(logger, sessionsStorage, rqStorage, aggStorage) masterConnection := gp.NewConnection(baseApp.L(), &cfg.MasterConnection, nil) masterSentinel := master_sentinel.NewSentinel(baseApp.L(), masterConnection) statActivityLister := stat_activity.NewLister(baseApp.L(), masterConnection) + backgroundStorage := master.NewBackgroundStorage(logger, sessionsStorage, rqStorage, aggStorage, statActivityLister) agentApp, err := NewApp(baseApp, cfg, statActivityLister, 
backgroundStorage) if err != nil { @@ -515,7 +515,7 @@ func Run(ctx context.Context, configFile string) error { logger.Infof("Starting master background tasks") ctxC, ctxF := context.WithCancel(ctx) defer ctxF() - err = master.InitBG(ctxC, logger, masterSentinel, statActivityLister, cfg, backgroundStorage) + err = master.InitBG(ctxC, logger, masterSentinel, cfg, backgroundStorage) if err != nil { logger.Fatal(err.Error()) return err diff --git a/internal/grpc/actions_test.go b/internal/grpc/actions_test.go index 4d32962..c88db77 100644 --- a/internal/grpc/actions_test.go +++ b/internal/grpc/actions_test.go @@ -23,7 +23,7 @@ func newTestActionsServer(t *testing.T) *ActionsServer { rq := storage.NewRunningQueriesStorage() sessStorage := gp.NewSessionsStorage(rq) agg := storage.NewAggregatedStorage(z) - bg := master.NewBackgroundStorage(z, sessStorage, rq, agg) + bg := master.NewBackgroundStorage(z, sessStorage, rq, agg, nil) return &ActionsServer{ Logger: z, Timeout: 5 * time.Second, diff --git a/internal/grpc/deps.go b/internal/grpc/deps.go deleted file mode 100644 index d53fa65..0000000 --- a/internal/grpc/deps.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:generate mockgen -source=deps.go -package=grpc_test -mock_names statActivityLister=MockStatActivityLister -destination mocks_test.go - -package grpc - -import ( - "context" - - "github.com/open-gpdb/yagpcc/internal/gp" -) - -type statActivityLister interface { - Start(ctx context.Context) error - Stop() - List(ctx context.Context) ([]*gp.GpStatActivity, error) -} diff --git a/internal/grpc/get_master_info.go b/internal/grpc/get_master_info.go index e2a8e7d..667b8a7 100644 --- a/internal/grpc/get_master_info.go +++ b/internal/grpc/get_master_info.go @@ -24,20 +24,18 @@ import ( type GetMasterInfoServer struct { pbm.UnimplementedGetGPInfoServer - clusterID string - logger *zap.SugaredLogger - statActivityLister statActivityLister - maxMessageSize int - backgroundStorage *master.BackgroundStorage + clusterID string + 
logger *zap.SugaredLogger + maxMessageSize int + backgroundStorage *master.BackgroundStorage } -func NewGetMasterInfoServer(clusterID string, logger *zap.SugaredLogger, statActivityLister statActivityLister, maxMessageSize int, backgroundStorage *master.BackgroundStorage) *GetMasterInfoServer { +func NewGetMasterInfoServer(clusterID string, logger *zap.SugaredLogger, maxMessageSize int, backgroundStorage *master.BackgroundStorage) *GetMasterInfoServer { return &GetMasterInfoServer{ - clusterID: clusterID, - logger: logger, - statActivityLister: statActivityLister, - maxMessageSize: maxMessageSize, - backgroundStorage: backgroundStorage, + clusterID: clusterID, + logger: logger, + maxMessageSize: maxMessageSize, + backgroundStorage: backgroundStorage, } } @@ -2557,7 +2555,7 @@ func (s *GetMasterInfoServer) GetGPSessions(ctx context.Context, in *pbm.GetGPSe queryType = pbm.RunningQueryType_RQT_TOP } // refresh list of sessions - err := s.backgroundStorage.TryRefreshSessionsFromGP(ctx, s.statActivityLister, true) + err := s.backgroundStorage.TryRefreshSessionsFromGP(ctx, true) if err != nil { s.logger.Errorf("error while refreshing session list: %v", err) } diff --git a/internal/grpc/get_master_info_test.go b/internal/grpc/get_master_info_test.go index c824952..ae83c7e 100644 --- a/internal/grpc/get_master_info_test.go +++ b/internal/grpc/get_master_info_test.go @@ -6,9 +6,9 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/timestamppb" pbm "github.com/open-gpdb/yagpcc/api/proto/agent_master" diff --git a/internal/grpc/grpc_test.go b/internal/grpc/grpc_test.go new file mode 100644 index 0000000..ef7fe83 --- /dev/null +++ b/internal/grpc/grpc_test.go @@ -0,0 +1,17 @@ +//go:generate mockgen -source=grpc_test.go -package=grpc_test -mock_names statActivityLister=MockStatActivityLister -destination mocks_test.go + 
+package grpc_test + +import ( + "context" + + "github.com/open-gpdb/yagpcc/internal/gp" + "github.com/open-gpdb/yagpcc/internal/gp/stat_activity" +) + +type statActivityLister interface { //nolint:unused // used by go:generate mockgen + Start(ctx context.Context) error + Stop() + List(ctx context.Context) ([]*gp.GpStatActivity, error) + ListAllSessions(context.Context) ([]stat_activity.SessionPid, error) +} diff --git a/internal/grpc/mocks_test.go b/internal/grpc/mocks_test.go index dfb6a84..e150ee1 100644 --- a/internal/grpc/mocks_test.go +++ b/internal/grpc/mocks_test.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: deps.go +// Source: grpc_test.go +// +// Generated by this command: +// +// mockgen -source=grpc_test.go -package=grpc_test -mock_names statActivityLister=MockStatActivityLister -destination mocks_test.go +// // Package grpc_test is a generated GoMock package. package grpc_test @@ -8,34 +13,66 @@ import ( context "context" reflect "reflect" - gomock "github.com/golang/mock/gomock" gp "github.com/open-gpdb/yagpcc/internal/gp" + stat_activity "github.com/open-gpdb/yagpcc/internal/gp/stat_activity" + gomock "go.uber.org/mock/gomock" ) -// MockStatActivityLister is a mock of statActivityLister interface +// MockStatActivityLister is a mock of statActivityLister interface. type MockStatActivityLister struct { ctrl *gomock.Controller recorder *MockStatActivityListerMockRecorder + isgomock struct{} } -// MockStatActivityListerMockRecorder is the mock recorder for MockStatActivityLister +// MockStatActivityListerMockRecorder is the mock recorder for MockStatActivityLister. type MockStatActivityListerMockRecorder struct { mock *MockStatActivityLister } -// NewMockStatActivityLister creates a new mock instance +// NewMockStatActivityLister creates a new mock instance. 
func NewMockStatActivityLister(ctrl *gomock.Controller) *MockStatActivityLister { mock := &MockStatActivityLister{ctrl: ctrl} mock.recorder = &MockStatActivityListerMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockStatActivityLister) EXPECT() *MockStatActivityListerMockRecorder { return m.recorder } -// Start mocks base method +// List mocks base method. +func (m *MockStatActivityLister) List(ctx context.Context) ([]*gp.GpStatActivity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx) + ret0, _ := ret[0].([]*gp.GpStatActivity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockStatActivityListerMockRecorder) List(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStatActivityLister)(nil).List), ctx) +} + +// ListAllSessions mocks base method. +func (m *MockStatActivityLister) ListAllSessions(arg0 context.Context) ([]stat_activity.SessionPid, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAllSessions", arg0) + ret0, _ := ret[0].([]stat_activity.SessionPid) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAllSessions indicates an expected call of ListAllSessions. +func (mr *MockStatActivityListerMockRecorder) ListAllSessions(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllSessions", reflect.TypeOf((*MockStatActivityLister)(nil).ListAllSessions), arg0) +} + +// Start mocks base method. 
func (m *MockStatActivityLister) Start(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Start", ctx) @@ -43,35 +80,20 @@ func (m *MockStatActivityLister) Start(ctx context.Context) error { return ret0 } -// Start indicates an expected call of Start -func (mr *MockStatActivityListerMockRecorder) Start(ctx interface{}) *gomock.Call { +// Start indicates an expected call of Start. +func (mr *MockStatActivityListerMockRecorder) Start(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockStatActivityLister)(nil).Start), ctx) } -// Stop mocks base method +// Stop mocks base method. func (m *MockStatActivityLister) Stop() { m.ctrl.T.Helper() m.ctrl.Call(m, "Stop") } -// Stop indicates an expected call of Stop +// Stop indicates an expected call of Stop. func (mr *MockStatActivityListerMockRecorder) Stop() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStatActivityLister)(nil).Stop)) } - -// List mocks base method -func (m *MockStatActivityLister) List(ctx context.Context) ([]*gp.GpStatActivity, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", ctx) - ret0, _ := ret[0].([]*gp.GpStatActivity) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// List indicates an expected call of List -func (mr *MockStatActivityListerMockRecorder) List(ctx interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStatActivityLister)(nil).List), ctx) -} diff --git a/internal/grpc/set_get_query_info_parallel_test.go b/internal/grpc/set_get_query_info_parallel_test.go index d79f2b1..c329d7a 100644 --- a/internal/grpc/set_get_query_info_parallel_test.go +++ b/internal/grpc/set_get_query_info_parallel_test.go @@ -30,7 +30,7 @@ func TestParallelSetGet(t *testing.T) { rqStorage := storage.NewRunningQueriesStorage() aggStorage := 
storage.NewAggregatedStorage(zLogger) sessStorage := gp.NewSessionsStorage(rqStorage) - backgroundStorage := master.NewBackgroundStorage(zLogger, sessStorage, rqStorage, aggStorage) + backgroundStorage := master.NewBackgroundStorage(zLogger, sessStorage, rqStorage, aggStorage, nil) tests := []struct { name string @@ -48,7 +48,7 @@ func TestParallelSetGet(t *testing.T) { {name: "test Get Queries", isSet: false, paramName: "ALL", ssid: 1, value: 1, cnt: 80, sleep: 0.01}, {name: "test Get Query1", isSet: false, paramName: "QUERY", ssid: 1, value: 1, cnt: 10000, sleep: 0}, } - dial := setupGRPCDialer(t, nil, backgroundStorage) + dial := setupGRPCDialer(t, backgroundStorage) ctx := context.Background() connectTimeout := 5 * time.Second diff --git a/internal/grpc/testutils_test.go b/internal/grpc/testutils_test.go index 7863138..fdc6d8b 100644 --- a/internal/grpc/testutils_test.go +++ b/internal/grpc/testutils_test.go @@ -8,8 +8,8 @@ import ( "sort" "testing" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" gogrpc "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -65,7 +65,7 @@ func assertQueriesInfoResponseEqual(t *testing.T, expected *pb.GetQueriesInfoRes return utils.AssertProtoMessagesEqual(t, normalize(expected), normalize(actual)) } -func setupGRPCDialer(t *testing.T, sessionMocker *MockStatActivityLister, backgroundStorage *master.BackgroundStorage) func(context.Context, string) (net.Conn, error) { +func setupGRPCDialer(t *testing.T, backgroundStorage *master.BackgroundStorage) func(context.Context, string) (net.Conn, error) { cfg, err := config.DefaultConfig() require.NoError(t, err, "error getting default config") @@ -88,7 +88,7 @@ func setupGRPCDialer(t *testing.T, sessionMocker *MockStatActivityLister, backgr pb.RegisterSetQueryInfoServer(server, &grpc.SetQueryInfoServer{Logger: zLogger, UpdateSessionMetrics: true, RQStorage: 
backgroundStorage.RQStorage, SessionsStorage: backgroundStorage.SessionStorage}) pb.RegisterGetQueryInfoServer(server, &grpc.GetQueryInfoServer{Logger: zLogger, MaxMessageSize: 100 * 1024 * 1024, RQStorage: backgroundStorage.RQStorage}) pb.RegisterAgentControlServer(server, &grpc.AgentControlServer{Logger: zLogger, RQStorage: backgroundStorage.RQStorage}) - pbm.RegisterGetGPInfoServer(server, grpc.NewGetMasterInfoServer("test", zLogger, sessionMocker, 100*1024*1024, backgroundStorage)) + pbm.RegisterGetGPInfoServer(server, grpc.NewGetMasterInfoServer("test", zLogger, 100*1024*1024, backgroundStorage)) go func() { if err := server.Serve(listener); err != nil { @@ -129,6 +129,7 @@ func setupGRPCClientSet(t *testing.T, sessionMocker *MockStatActivityLister) (*g ctrl := gomock.NewController(t) sessionMocker = NewMockStatActivityLister(ctrl) sessionMocker.EXPECT().List(gomock.Any()).AnyTimes() + sessionMocker.EXPECT().ListAllSessions(gomock.Any()).AnyTimes() } file, err := os.Create("trace.log") @@ -137,11 +138,11 @@ func setupGRPCClientSet(t *testing.T, sessionMocker *MockStatActivityLister) (*g rqStorage := storage.NewRunningQueriesStorage() sessStorage := gp.NewSessionsStorage(rqStorage) aggStorage := storage.NewAggregatedStorage(zLogger) - backgroundStorage := master.NewBackgroundStorage(zLogger, sessStorage, rqStorage, aggStorage) + backgroundStorage := master.NewBackgroundStorage(zLogger, sessStorage, rqStorage, aggStorage, sessionMocker) conn, err := gogrpc.NewClient( "localhost", - gogrpc.WithContextDialer(setupGRPCDialer(t, sessionMocker, backgroundStorage)), + gogrpc.WithContextDialer(setupGRPCDialer(t, backgroundStorage)), gogrpc.WithTransportCredentials(insecure.NewCredentials()), ) diff --git a/internal/master/background.go b/internal/master/background.go index 56d2b74..8fc4b70 100644 --- a/internal/master/background.go +++ b/internal/master/background.go @@ -3,7 +3,6 @@ package master import ( "context" "fmt" - "net" "os" "sort" "sync" @@ -11,8 +10,6 @@ 
import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/insecure" pbm "github.com/open-gpdb/yagpcc/api/proto/agent_master" pb "github.com/open-gpdb/yagpcc/api/proto/agent_segment" @@ -38,67 +35,20 @@ type ( segmentMap map[string]*segmentAddr BackgroundStorage struct { - l *zap.SugaredLogger - SessionStorage *gp.SessionsStorage - AggStorage *storage.AggregatedStorage - RQStorage *storage.RunningQueriesStorage + l *zap.SugaredLogger + SessionStorage *gp.SessionsStorage + AggStorage *storage.AggregatedStorage + RQStorage *storage.RunningQueriesStorage + statActivityLister statActivityLister } ) var ( - segChan chan segmentAddr - segConnections map[string]*grpc.ClientConn = make(map[string]*grpc.ClientConn) - segConnectionLock sync.Mutex - segCount int - segCountLock sync.Mutex + segChan chan segmentAddr + segCount int + segCountLock sync.Mutex ) -func getSegAddr(hostname string, portn uint32) string { - return fmt.Sprintf("%s:%d", hostname, portn) -} - -func getGrpcClientConnection(ctx context.Context, hostname string, portn uint32, segConnectTimeoutSec float64) (*grpc.ClientConn, error) { - var err error - segConnectionLock.Lock() - defer segConnectionLock.Unlock() - conn, ok := segConnections[hostname] - if ok { - if conn.GetState() == connectivity.Ready { - return conn, nil - } - } - connectTimeout := time.Second * time.Duration(segConnectTimeoutSec) - if portn > 0 { - conn, err = grpc.NewClient( - getSegAddr(hostname, portn), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithConnectParams(grpc.ConnectParams{ - MinConnectTimeout: connectTimeout, - }), - ) - if err != nil { - return nil, err - } - } else { - conn, err = grpc.NewClient( - hostname, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, "unix", addr) - }), - 
grpc.WithConnectParams(grpc.ConnectParams{ - MinConnectTimeout: connectTimeout, - }), - ) - if err != nil { - return nil, err - } - } - segConnections[hostname] = conn - return conn, nil -} - func (bs *BackgroundStorage) SendSegmentRefreshMessages(ctx context.Context, pullRateSec float64, configCacheDurability time.Duration, portn uint32, customSegmentList *config.SegmentList) error { durationBetweenLoop := time.Second * time.Duration(pullRateSec) @@ -358,10 +308,9 @@ func queryCompleted(qKey *storage.QueryKey, qVal *storage.RunningQuery, segmentG func (bs *BackgroundStorage) TryRefreshSessionsFromGP( ctx context.Context, - statActivityLister statActivityLister, clearDeletedSessions bool, ) error { - newSesList, err := statActivityLister.List(ctx) + newSesList, err := bs.statActivityLister.List(ctx) if err != nil { return fmt.Errorf("error getting sessions: %w", err) } @@ -458,10 +407,7 @@ func (bs *BackgroundStorage) ClearCompletedQueries(ctx context.Context, return nil } -func (bs *BackgroundStorage) RefreshSessions(ctx context.Context, - statActivityLister statActivityLister, - sessionRefreshInterval time.Duration, - clearDeletedSessions bool) error { +func (bs *BackgroundStorage) RefreshSessions(ctx context.Context, sessionRefreshInterval time.Duration, clearDeletedSessions bool) error { for { currTime := time.Now() nextTime := currTime.Truncate(sessionRefreshInterval).Add(sessionRefreshInterval) @@ -471,7 +417,7 @@ func (bs *BackgroundStorage) RefreshSessions(ctx context.Context, return fmt.Errorf("done context with %v", ctx.Err()) default: bs.l.Info("Refresh session List") - err := bs.TryRefreshSessionsFromGP(ctx, statActivityLister, clearDeletedSessions) + err := bs.TryRefreshSessionsFromGP(ctx, clearDeletedSessions) if err != nil { bs.l.Errorf("fail to refresh session list %v", err) return err @@ -528,12 +474,13 @@ func InitConnection(ctx context.Context, l *zap.SugaredLogger, cfg *config.Confi return nil } -func NewBackgroundStorage(l *zap.SugaredLogger, 
sessionStorage *gp.SessionsStorage, rqStorage *storage.RunningQueriesStorage, aggStorage *storage.AggregatedStorage) *BackgroundStorage { +func NewBackgroundStorage(l *zap.SugaredLogger, sessionStorage *gp.SessionsStorage, rqStorage *storage.RunningQueriesStorage, aggStorage *storage.AggregatedStorage, sActivityLister statActivityLister) *BackgroundStorage { return &BackgroundStorage{ - l: l, - SessionStorage: sessionStorage, - AggStorage: aggStorage, - RQStorage: rqStorage, + l: l, + SessionStorage: sessionStorage, + AggStorage: aggStorage, + RQStorage: rqStorage, + statActivityLister: sActivityLister, } } @@ -541,7 +488,6 @@ func InitBG( ctx context.Context, l *zap.SugaredLogger, masterSentinel masterSentinel, - statActivityLister statActivityLister, cfg *config.Config, backgroundStorage *BackgroundStorage, ) error { @@ -568,8 +514,11 @@ func InitBG( return nil }) - if err = statActivityLister.Start(ctx); err != nil { - return fmt.Errorf("error starting stat activity lister") + if backgroundStorage.statActivityLister == nil { + return fmt.Errorf("stat activity lister is nil") + } + if err = backgroundStorage.statActivityLister.Start(ctx); err != nil { + return fmt.Errorf("error starting stat activity lister: %w", err) } errG.Go(func() error { @@ -602,7 +551,7 @@ func InitBG( }, ) errG.Go(func() error { - err := backgroundStorage.RefreshSessions(ctxI, statActivityLister, cfg.SessionRefreshInterval, cfg.ClearDeletedSessions) + err := backgroundStorage.RefreshSessions(ctxI, cfg.SessionRefreshInterval, cfg.ClearDeletedSessions) l.Errorf("got %v refresh session and queries", err) return err }, @@ -615,7 +564,7 @@ func InitBG( ) err = errG.Wait() if err != nil { - statActivityLister.Stop() + backgroundStorage.statActivityLister.Stop() l.Errorf("Fail in background precesses - done work with %v", err) return err } diff --git a/internal/master/deps.go b/internal/master/deps.go index c56b77e..d0db59c 100644 --- a/internal/master/deps.go +++ b/internal/master/deps.go @@ 
-4,12 +4,14 @@ import ( "context" "github.com/open-gpdb/yagpcc/internal/gp" + "github.com/open-gpdb/yagpcc/internal/gp/stat_activity" ) type statActivityLister interface { Start(ctx context.Context) error Stop() List(ctx context.Context) ([]*gp.GpStatActivity, error) + ListAllSessions(ctx context.Context) ([]stat_activity.SessionPid, error) } type masterSentinel interface { diff --git a/internal/master/procfs.go b/internal/master/procfs.go new file mode 100644 index 0000000..849af0d --- /dev/null +++ b/internal/master/procfs.go @@ -0,0 +1,123 @@ +package master + +import ( + "context" + "fmt" + "time" + + "golang.org/x/sync/errgroup" + + "google.golang.org/grpc" + + "github.com/open-gpdb/yagpcc/internal/gp/stat_activity" + "github.com/open-gpdb/yagpcc/internal/storage" + + pb "github.com/open-gpdb/yagpcc/api/proto/agent_segment" +) + +const ( + jobsPerQuery = 1000 +) + +type ( + hostJobMap = map[string][]stat_activity.SessionPid +) + +// getJobsMap groups the given session processes by the hostname that +// storage.GetHostnameForSegindex returns for each process's segment id. +func (bs *BackgroundStorage) getJobsMap(sessions []stat_activity.SessionPid) hostJobMap { + hostJobs := make(hostJobMap) + // make work for each host + for _, process := range sessions { + segHost := storage.GetHostnameForSegindex(int32(process.GpSegmentId)) + jobList, ok := hostJobs[segHost] + if !ok { + jobList = make([]stat_activity.SessionPid, 0, 10) + } + jobList = append(jobList, stat_activity.SessionPid{ + GpSegmentId: process.GpSegmentId, + Pid: process.Pid, + SessId: process.SessId, + }) + hostJobs[segHost] = jobList + } + return hostJobs +} + +// processProcfsRequests sends reqs to the GetQueryInfo gRPC service reached +// through getGrpcClientConnection for hostname/portn, calling GetPidProcStat +// in batches of at most jobsPerQuery entries. One gatherTimeout deadline covers +// all batches; the first failing call aborts the rest. It returns ctx.Err() if +// ctx is done while a batch is being built. Errors are wrapped with %w so the +// underlying cause stays reachable via errors.Is / errors.As. +func (bs *BackgroundStorage) processProcfsRequests(ctx context.Context, hostname string, portn uint32, gatherTimeout time.Duration, maxMsgSize int, reqs []stat_activity.SessionPid) error { + grpcConn, err := getGrpcClientConnection(ctx, hostname, portn, gatherTimeout.Seconds()) + if err != nil { + return fmt.Errorf("grpc client connection error: %w", err) + } + cGet := pb.NewGetQueryInfoClient(grpcConn) + ctxTimeout, ctxCancel := context.WithTimeout(ctx, gatherTimeout) + defer
ctxCancel() + maxSizeOption := grpc.MaxCallRecvMsgSize(maxMsgSize) + msgReq := &pb.GetPidProcInfoReq{ + SegmentProcess: make([]*pb.SegmentProcess, 0, 10), + } + for _, req := range reqs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + msgReq.SegmentProcess = append(msgReq.SegmentProcess, &pb.SegmentProcess{ + GpSegmentId: int64(req.GpSegmentId), + Pid: int64(req.Pid), + SessId: int64(req.SessId), + }) + if len(msgReq.SegmentProcess) >= jobsPerQuery { + _, errGet := cGet.GetPidProcStat(ctxTimeout, msgReq, maxSizeOption) + if errGet != nil { + return fmt.Errorf("grpc get pid proc stat error: %w", errGet) + } + msgReq.SegmentProcess = make([]*pb.SegmentProcess, 0, 10) + } + } + } + if len(msgReq.SegmentProcess) > 0 { + _, errGet := cGet.GetPidProcStat(ctxTimeout, msgReq, maxSizeOption) + if errGet != nil { + return fmt.Errorf("grpc get pid proc stat error: %w", errGet) + } + + } + return nil +} + +// GatherProcfsStat lists all sessions through the stat activity lister, groups +// them per host and runs processProcfsRequests for every host, with at most +// nPullers hosts in flight at once. The whole gather is bounded by +// gatherTimeout; the first error cancels the group context and is returned. +func (bs *BackgroundStorage) GatherProcfsStat(ctx context.Context, nPullers int, portn uint32, gatherTimeout time.Duration, maxMsgSize int) error { + if nPullers <= 0 { + return fmt.Errorf("nPullers must be greater than 0, got %d", nPullers) + } + bs.l.Debug("GatherProcfsStat") + sessions, err := bs.statActivityLister.ListAllSessions(ctx) + if err != nil { + return fmt.Errorf("error listing sessions pids: %w", err) + } + hostJobs := bs.getJobsMap(sessions) + + ctxT, ctxTC := context.WithTimeout(ctx, gatherTimeout) + defer ctxTC() + + g, ctxG := errgroup.WithContext(ctxT) + // nPullers caps how many hosts are polled concurrently. + g.SetLimit(nPullers) + + for hostname, procfsProcesses := range hostJobs { + g.Go(func() error { + return bs.processProcfsRequests(ctxG, hostname, portn, gatherTimeout, maxMsgSize, procfsProcesses) + }) + } + + return g.Wait() +} diff --git a/internal/master/procfs_test.go b/internal/master/procfs_test.go new file mode 100644 index 0000000..0fa68ee --- /dev/null +++ b/internal/master/procfs_test.go @@ -0,0 +1,502 @@ +package master + +import ( + "context" + "fmt" + "log" + "net" + "sync" + "testing" + "time" + +
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + pb "github.com/open-gpdb/yagpcc/api/proto/agent_segment" + "github.com/open-gpdb/yagpcc/internal/gp" + "github.com/open-gpdb/yagpcc/internal/gp/stat_activity" + "github.com/open-gpdb/yagpcc/internal/storage" +) + +// --- mock statActivityLister --- + +type mockStatActivityLister struct { + sessions []stat_activity.SessionPid + sessionsErr error + listCalled bool +} + +func (m *mockStatActivityLister) Start(context.Context) error { return nil } +func (m *mockStatActivityLister) Stop() {} +func (m *mockStatActivityLister) List(context.Context) ([]*gp.GpStatActivity, error) { + return nil, nil +} +func (m *mockStatActivityLister) ListAllSessions(context.Context) ([]stat_activity.SessionPid, error) { + m.listCalled = true + return m.sessions, m.sessionsErr +} + +// --- fake gRPC server for GetPidProcStat --- + +type fakeProcStatServer struct { + pb.UnimplementedGetQueryInfoServer + mu sync.Mutex + called bool + lastReq *pb.GetPidProcInfoReq +} + +func (s *fakeProcStatServer) GetPidProcStat(_ context.Context, req *pb.GetPidProcInfoReq) (*pb.GetPidProcInfoResponse, error) { + s.mu.Lock() + s.called = true + s.lastReq = req + s.mu.Unlock() + return &pb.GetPidProcInfoResponse{}, nil +} + +func (s *fakeProcStatServer) snapshot() (bool, *pb.GetPidProcInfoReq) { + s.mu.Lock() + defer s.mu.Unlock() + return s.called, s.lastReq +} + +func (s *fakeProcStatServer) GetMetricQueries(_ context.Context, _ *pb.GetQueriesInfoReq) (*pb.GetQueriesInfoResponse, error) { + return &pb.GetQueriesInfoResponse{}, nil +} + +type failingProcStatServer struct { + pb.UnimplementedGetQueryInfoServer +} + +func (s *failingProcStatServer) GetPidProcStat(context.Context, *pb.GetPidProcInfoReq) (*pb.GetPidProcInfoResponse, error) { + return nil, fmt.Errorf("simulated gRPC error") +} + +func 
(s *failingProcStatServer) GetMetricQueries(context.Context, *pb.GetQueriesInfoReq) (*pb.GetQueriesInfoResponse, error) { + return &pb.GetQueriesInfoResponse{}, nil +} + +// setupBufconnServer creates a gRPC server on a bufconn listener, registers +// the provided GetQueryInfo service implementation, starts serving, and +// returns the listener. +func setupBufconnServer(t *testing.T, srv pb.GetQueryInfoServer) *bufconn.Listener { + t.Helper() + lis := bufconn.Listen(1024 * 1024) + s := grpc.NewServer() + pb.RegisterGetQueryInfoServer(s, srv) + go func() { + if err := s.Serve(lis); err != nil { + log.Printf("bufconn server exited: %v", err) + } + }() + t.Cleanup(func() { s.Stop() }) + return lis +} + +func dialBufconn(t *testing.T, lis *bufconn.Listener) *grpc.ClientConn { + t.Helper() + conn, err := grpc.NewClient( + "passthrough:///bufconn", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return lis.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +func newTestLogger() *zap.SugaredLogger { + cfg := zap.NewDevelopmentConfig() + cfg.Level = zap.NewAtomicLevelAt(zap.WarnLevel) + l, _ := cfg.Build() + return l.Sugar() +} + +// ============================================================ +// Tests for getJobsMap +// ============================================================ + +func TestGetJobsMap_EmptyInput(t *testing.T) { + bs := &BackgroundStorage{l: newTestLogger()} + result := bs.getJobsMap(nil) + assert.NotNil(t, result) + assert.Empty(t, result) + + result2 := bs.getJobsMap([]stat_activity.SessionPid{}) + assert.NotNil(t, result2) + assert.Empty(t, result2) +} + +func TestGetJobsMap_SingleHost(t *testing.T) { + storage.SetHostnameForSegindex(10, "host-a") + + sessions := []stat_activity.SessionPid{ + {GpSegmentId: 10, Pid: 100, SessId: 1}, + {GpSegmentId: 10, Pid: 200, SessId: 2}, + } + + bs := &BackgroundStorage{l: 
newTestLogger()} + result := bs.getJobsMap(sessions) + + // The map should contain an entry for "host-a" + _, exists := result["host-a"] + assert.True(t, exists, "expected key 'host-a' in hostJobMap") +} + +func TestGetJobsMap_MultipleHosts(t *testing.T) { + storage.SetHostnameForSegindex(20, "host-b") + storage.SetHostnameForSegindex(21, "host-c") + + sessions := []stat_activity.SessionPid{ + {GpSegmentId: 20, Pid: 100, SessId: 1}, + {GpSegmentId: 21, Pid: 200, SessId: 2}, + {GpSegmentId: 20, Pid: 300, SessId: 3}, + } + + bs := &BackgroundStorage{l: newTestLogger()} + result := bs.getJobsMap(sessions) + + // Should have entries for both hosts + assert.Contains(t, result, "host-b") + assert.Contains(t, result, "host-c") +} + +func TestGetJobsMap_UnknownSegindex(t *testing.T) { + // When segindex is not in the config storage, GetHostnameForSegindex + // returns the string representation of the segindex. + sessions := []stat_activity.SessionPid{ + {GpSegmentId: 9999, Pid: 100, SessId: 1}, + } + + bs := &BackgroundStorage{l: newTestLogger()} + result := bs.getJobsMap(sessions) + _, exists := result["9999"] + assert.True(t, exists, "expected key '9999' for unknown segindex") +} + +// ============================================================ +// Tests for processProcfsRequests +// ============================================================ + +func TestProcessProcfsRequests_Success(t *testing.T) { + fakeSrv := &fakeProcStatServer{} + lis := setupBufconnServer(t, fakeSrv) + + // Inject the bufconn connection into the global connection cache + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-procfs-success-%d", time.Now().UnixNano()) + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + reqs := []stat_activity.SessionPid{ + {GpSegmentId: 1, Pid: 100, SessId: 10}, + {GpSegmentId: 2, Pid: 200, SessId: 20}, 
+ } + + bs := &BackgroundStorage{l: newTestLogger()} + ctx := context.Background() + err := bs.processProcfsRequests(ctx, hostname, 0, 5*time.Second, 4*1024*1024, reqs) + require.NoError(t, err) + called, lastReq := fakeSrv.snapshot() + assert.True(t, called, "expected GetPidProcStat to be called") + require.NotNil(t, lastReq) + assert.Len(t, lastReq.SegmentProcess, 2) + + // Verify the proto message fields + sp0 := lastReq.SegmentProcess[0] + assert.Equal(t, int64(1), sp0.GpSegmentId) + assert.Equal(t, int64(100), sp0.Pid) + assert.Equal(t, int64(10), sp0.SessId) + + sp1 := lastReq.SegmentProcess[1] + assert.Equal(t, int64(2), sp1.GpSegmentId) + assert.Equal(t, int64(200), sp1.Pid) + assert.Equal(t, int64(20), sp1.SessId) +} + +func TestProcessProcfsRequests_GrpcError(t *testing.T) { + failSrv := &failingProcStatServer{} + lis := setupBufconnServer(t, failSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-procfs-fail-%d", time.Now().UnixNano()) + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + reqs := []stat_activity.SessionPid{ + {GpSegmentId: 1, Pid: 100, SessId: 10}, + } + + bs := &BackgroundStorage{l: newTestLogger()} + ctx := context.Background() + err := bs.processProcfsRequests(ctx, hostname, 0, 5*time.Second, 4*1024*1024, reqs) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated gRPC error") +} + +func TestProcessProcfsRequests_CancelledContext(t *testing.T) { + fakeSrv := &fakeProcStatServer{} + lis := setupBufconnServer(t, fakeSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-procfs-cancel-%d", time.Now().UnixNano()) + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + reqs := 
[]stat_activity.SessionPid{ + {GpSegmentId: 1, Pid: 100, SessId: 10}, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // With a cancelled context, processProcfsRequests should skip building + // the request body (due to select on ctx.Done()) and return nil. + bs := &BackgroundStorage{l: newTestLogger()} + err := bs.processProcfsRequests(ctx, hostname, 0, 5*time.Second, 4*1024*1024, reqs) + // The function returns nil when context is cancelled during request building, + // but may return an error from the gRPC call if the request was already built. + // Either outcome is acceptable with a cancelled context. + if err != nil { + assert.ErrorIs(t, ctx.Err(), context.Canceled) + } +} + +func TestProcessProcfsRequests_EmptyRequests(t *testing.T) { + fakeSrv := &fakeProcStatServer{} + lis := setupBufconnServer(t, fakeSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-procfs-empty-%d", time.Now().UnixNano()) + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + bs := &BackgroundStorage{l: newTestLogger()} + ctx := context.Background() + err := bs.processProcfsRequests(ctx, hostname, 0, 5*time.Second, 4*1024*1024, nil) + require.NoError(t, err) + called, _ := fakeSrv.snapshot() + assert.False(t, called, "GetPidProcStat should not be called with empty segment list") +} + +// ============================================================ +// Tests for GatherProcfsStat +// ============================================================ + +func TestGatherProcfsStat_ListAllSessionsError(t *testing.T) { + mock := &mockStatActivityLister{ + sessionsErr: fmt.Errorf("db connection failed"), + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + err := bs.GatherProcfsStat(context.Background(), 2, 50051, 5*time.Second, 4*1024*1024) 
+ require.Error(t, err) + assert.Contains(t, err.Error(), "db connection failed") + assert.True(t, mock.listCalled) +} + +func TestGatherProcfsStat_EmptySessions(t *testing.T) { + mock := &mockStatActivityLister{ + sessions: []stat_activity.SessionPid{}, + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + err := bs.GatherProcfsStat(context.Background(), 2, 50051, 5*time.Second, 4*1024*1024) + require.NoError(t, err) + assert.True(t, mock.listCalled) +} + +func TestGatherProcfsStat_WithSessions(t *testing.T) { + // Set up a fake gRPC server + fakeSrv := &fakeProcStatServer{} + lis := setupBufconnServer(t, fakeSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-gather-%d", time.Now().UnixNano()) + + // Register the hostname in the segment config + storage.SetHostnameForSegindex(30, hostname) + + // Inject the bufconn connection + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + mock := &mockStatActivityLister{ + sessions: []stat_activity.SessionPid{ + {GpSegmentId: 30, Pid: 100, SessId: 1}, + {GpSegmentId: 30, Pid: 200, SessId: 2}, + }, + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + err := bs.GatherProcfsStat(context.Background(), 2, 0, 5*time.Second, 4*1024*1024) + require.NoError(t, err) + assert.True(t, mock.listCalled) +} + +func TestGatherProcfsStat_ContextCancelled(t *testing.T) { + mock := &mockStatActivityLister{ + sessions: []stat_activity.SessionPid{ + {GpSegmentId: 40, Pid: 100, SessId: 1}, + }, + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // With a cancelled context, the timeout context creation will produce + // an already-done context, so the pool tasks should 
handle it gracefully. + err := bs.GatherProcfsStat(ctx, 2, 50051, 5*time.Second, 4*1024*1024) + // The error may be nil (if tasks detect cancellation early) or non-nil + // (if the gRPC call fails due to cancelled context). Both are acceptable. + _ = err +} + +func TestGatherProcfsStat_GrpcFailure(t *testing.T) { + failSrv := &failingProcStatServer{} + lis := setupBufconnServer(t, failSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-gather-fail-%d", time.Now().UnixNano()) + + storage.SetHostnameForSegindex(50, hostname) + + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + mock := &mockStatActivityLister{ + sessions: []stat_activity.SessionPid{ + {GpSegmentId: 50, Pid: 100, SessId: 1}, + }, + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + err := bs.GatherProcfsStat(context.Background(), 2, 0, 5*time.Second, 4*1024*1024) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated gRPC error") +} + +func TestGatherProcfsStat_ManySessionsBatching(t *testing.T) { + // Create more than JobsPerQuery sessions to verify batching logic + fakeSrv := &fakeProcStatServer{} + lis := setupBufconnServer(t, fakeSrv) + + conn := dialBufconn(t, lis) + hostname := fmt.Sprintf("test-gather-batch-%d", time.Now().UnixNano()) + + storage.SetHostnameForSegindex(60, hostname) + + segConnectionLock.Lock() + segConnections[hostname] = conn + segConnectionLock.Unlock() + t.Cleanup(func() { + segConnectionLock.Lock() + delete(segConnections, hostname) + segConnectionLock.Unlock() + }) + + // Create JobsPerQuery + 5 sessions to trigger at least 2 batches + sessions := make([]stat_activity.SessionPid, 0, jobsPerQuery+5) + for i := 0; i < jobsPerQuery+5; i++ { + sessions = append(sessions, stat_activity.SessionPid{ + GpSegmentId: 60, + Pid: 100 + i, + SessId: i + 1, + }) 
+ } + + mock := &mockStatActivityLister{ + sessions: sessions, + } + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + err := bs.GatherProcfsStat(context.Background(), 4, 0, 10*time.Second, 4*1024*1024) + require.NoError(t, err) + assert.True(t, mock.listCalled) + // The fake server should have been called (at least once for the batches) + called, _ := fakeSrv.snapshot() + assert.True(t, called) +} + +// ============================================================ +// Tests for constants +// ============================================================ + +func TestConstants(t *testing.T) { + assert.Equal(t, 1000, jobsPerQuery) +} + +func TestGatherProcfsStat_InvalidNPullers(t *testing.T) { + mock := &mockStatActivityLister{} + bs := &BackgroundStorage{ + l: newTestLogger(), + statActivityLister: mock, + } + + for _, n := range []int{0, -1, -100} { + err := bs.GatherProcfsStat(context.Background(), n, 50051, 5*time.Second, 4*1024*1024) + require.Error(t, err) + assert.Contains(t, err.Error(), "nPullers must be greater than 0") + assert.False(t, mock.listCalled, "ListAllSessions should not be called for invalid nPullers") + } +} diff --git a/internal/master/utils.go b/internal/master/utils.go new file mode 100644 index 0000000..c3469f2 --- /dev/null +++ b/internal/master/utils.go @@ -0,0 +1,57 @@ +package master + +import ( + "context" + "net" + "strconv" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" +) + +var ( + segConnections = make(map[string]*grpc.ClientConn) + segConnectionLock sync.Mutex +) + +func getGrpcClientConnection(ctx context.Context, hostname string, portn uint32, segConnectTimeoutSec float64) (*grpc.ClientConn, error) { + var err error + segConnectionLock.Lock() + defer segConnectionLock.Unlock() + conn, ok := segConnections[hostname] + if ok { + if conn.GetState() != connectivity.Shutdown { + return conn, nil + } + } + 
connectTimeout := time.Second * time.Duration(segConnectTimeoutSec) + if portn > 0 { + conn, err = grpc.NewClient( + net.JoinHostPort(hostname, strconv.FormatUint(uint64(portn), 10)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithConnectParams(grpc.ConnectParams{ + MinConnectTimeout: connectTimeout, + }), + ) + } else { + conn, err = grpc.NewClient( + hostname, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", addr) + }), + grpc.WithConnectParams(grpc.ConnectParams{ + MinConnectTimeout: connectTimeout, + }), + ) + } + if err != nil { + return nil, err + } + segConnections[hostname] = conn + return conn, nil +}