diff --git a/admin/.env b/admin/.env new file mode 100644 index 0000000..d5bdf88 --- /dev/null +++ b/admin/.env @@ -0,0 +1,2 @@ +SERVER_HOST=localhost +SERVER_PORT=50051 \ No newline at end of file diff --git a/admin/README.md b/admin/README.md new file mode 100644 index 0000000..1a96737 --- /dev/null +++ b/admin/README.md @@ -0,0 +1,26 @@ +# Admin Client + +## Установка зависимостей +```bash +pip install -r requirements.txt +``` + +## Генерация proto-файлов +```bash +python -m grpc_tools.protoc \ + --python_out=. \ + --grpc_python_out=. \ + -I . \ + admin_service.proto +``` + +## Запуск контроллера +```bash +cd ../controller +bazel run //cmd/grpc_server:grpc_server +``` + +## Запуск клиента +```bash +python admin.py --file config.toml +``` \ No newline at end of file diff --git a/admin/admin.py b/admin/admin.py new file mode 100644 index 0000000..016ce10 --- /dev/null +++ b/admin/admin.py @@ -0,0 +1,42 @@ +import argparse +import grpc +from dotenv import load_dotenv +import os + +import admin_service_pb2 +import admin_service_pb2_grpc + +load_dotenv(".env") + +class AdminClient: + def __init__(self): + self.host: str = os.environ["SERVER_HOST"] + self.port: str = os.environ["SERVER_PORT"] + self.channel: grpc.channel = grpc.insecure_channel(f"{self.host}:{self.port}") + self.stub: admin_service_pb2_grpc.AdminServiceStub = admin_service_pb2_grpc.AdminServiceStub(self.channel) + + def load_config(self, toml_file: str) -> admin_service_pb2.LoadConfigResponse: + with open(toml_file, "rb") as f: + content: bytes = f.read() + + request: admin_service_pb2.LoadConfigRequest = (admin_service_pb2.LoadConfigRequest(config_data=content)) + return self.stub.LoadConfig(request) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--file", default="config.toml") + args = parser.parse_args() + + try: + client = AdminClient() + client.load_config(args.file) + print("Config loaded") + except Exception as e: + print(f"Error loading config: {e}") + return 1 + + return 0 + +if __name__ == "__main__": + main() + diff --git a/admin/admin_service.proto b/admin/admin_service.proto new file mode 100644 index 0000000..c587721 --- /dev/null +++ b/admin/admin_service.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package admin_service; + +option go_package = "github.com/moevm/grpc_server/pkg/proto/admin_service"; + +service AdminService { + rpc LoadConfig(LoadConfigRequest) returns (LoadConfigResponse) {} +} + +message LoadConfigRequest { + bytes config_data = 1; +} + +message LoadConfigResponse { + bool success = 1; +} diff --git a/admin/config.toml b/admin/config.toml new file mode 100644 index 0000000..3964b87 --- /dev/null +++ b/admin/config.toml @@ -0,0 +1,22 @@ +[global.rules] +block_categories = ["MALWARE", "SOCIAL"] +block_domains = ["youtube.com", "tiktok.com"] +allow_domains = ["github.com", "stackoverflow.com"] + +[global.rules.block_by_trust] +ENTERTAINMENT = 6 +NEWS = 4 + +[filters.filter_1] +block_categories = ["SOCIAL", "ENTERTAINMENT"] +block_domains = ["instagram.com"] +allow_domains = ["vk.com"] + +[filters.filter_1.block_by_trust] +SOCIAL = 8 +ENTERTAINMENT = 7 + +[filters.filter_2] +block_categories = ["MALWARE"] +allow_domains = ["github.com", "gitlab.com"] +min_trust_level = 3 \ No newline at end of file diff --git a/admin/requirements.txt b/admin/requirements.txt new file mode 100644 index 0000000..3bf5f6e --- /dev/null +++ b/admin/requirements.txt @@ -0,0 +1,4 @@ +grpcio==1.60.0 +grpcio-tools==1.60.0 +protobuf==4.25.1 +python-dotenv==1.0.0 \ No newline at end of file diff --git a/controller/MODULE.bazel b/controller/MODULE.bazel index 2efd7fc..06b9ce2 100644 --- a/controller/MODULE.bazel +++ b/controller/MODULE.bazel @@ -25,6 +25,7 @@ use_repo( go_deps, "com_github_joho_godotenv", "com_github_stretchr_testify", + "com_github_pelletier_go_toml", "org_golang_google_grpc", "org_golang_google_protobuf", "org_golang_x_text", diff --git a/controller/cmd/grpc_server/BUILD b/controller/cmd/grpc_server/BUILD index 2c14b95..6923a67 100644 --- a/controller/cmd/grpc_server/BUILD +++ b/controller/cmd/grpc_server/BUILD @@ -15,7 +15,9 @@ go_library( "//internal/config", "//internal/grpcserver", "//internal/manager", + "//pkg/proto/admin_service:admin_service_go_proto", "//pkg/proto/file_service:file_service_go_proto", + "//pkg/proto/communication:communication_go_proto", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//reflection", ], diff --git a/controller/cmd/grpc_server/main.go b/controller/cmd/grpc_server/main.go index 7b5f99f..46d4ffd 100644 --- a/controller/cmd/grpc_server/main.go +++ b/controller/cmd/grpc_server/main.go @@ -10,16 +10,22 @@ import ( pb "github.com/moevm/grpc_server/pkg/proto/file_service" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + adminPb "github.com/moevm/grpc_server/pkg/proto/admin_service" + commPb "github.com/moevm/grpc_server/pkg/proto/communication" ) func main() { cfg := config.Load() - + adminServer := grpcserver.NewAdminServer() mgr, err := manager.NewManager() if err != nil { log.Fatalf("manager.NewManager(): %v", err) } + adminServer.SetManager(mgr) + + dataServer := grpcserver.NewDataServer(mgr) + lis, err := net.Listen("tcp", net.JoinHostPort(cfg.Host, cfg.Port)) if err != nil { log.Fatalf("failed to listen: %v", err) @@ -31,7 +37,9 @@ func main() { } service := grpc.NewServer(serverOpts...) + adminPb.RegisterAdminServiceServer(service, adminServer) pb.RegisterFileServiceServer(service, grpcserver.NewServer(cfg.AllowedChars, mgr)) + commPb.RegisterDataServiceServer(service, dataServer) reflection.Register(service) log.Printf("Server starting on %s:%s", cfg.Host, cfg.Port) diff --git a/controller/go.mod b/controller/go.mod index e2a0fff..306b484 100644 --- a/controller/go.mod +++ b/controller/go.mod @@ -11,6 +11,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pelletier/go-toml v1.9.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/controller/go.sum b/controller/go.sum index 3dd6f80..9f76e60 100644 --- a/controller/go.sum +++ b/controller/go.sum @@ -12,6 +12,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= +github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/controller/internal/grpcserver/BUILD b/controller/internal/grpcserver/BUILD index ffc12e3..818872c 100644 --- a/controller/internal/grpcserver/BUILD +++ b/controller/internal/grpcserver/BUILD @@ -2,21 +2,29 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "grpcserver", - srcs = ["server.go"], + srcs = [ + "server.go", + "admin_server.go", + "data_server.go", + ], importpath = "github.com/moevm/grpc_server/internal/grpcserver", visibility = ["//visibility:public"], deps = [ "//internal/manager", "//pkg/proto/file_service:file_service_go_proto", + "//pkg/proto/admin_service:admin_service_go_proto", + "//pkg/proto/communication:communication_go_proto", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", - ], + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_protobuf//types/known/emptypb", + ] ) go_test( name = "grpcserver_test", srcs = ["server_test.go"], - embed = [ ":grpcserver" ], + embed = [":grpcserver"], deps = [ "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/controller/internal/grpcserver/admin_server.go b/controller/internal/grpcserver/admin_server.go new file mode 100644 index 0000000..386bb85 --- /dev/null +++ b/controller/internal/grpcserver/admin_server.go @@ -0,0 +1,37 @@ +package grpcserver + +import ( + "context" + "github.com/moevm/grpc_server/internal/manager" + pb "github.com/moevm/grpc_server/pkg/proto/admin_service" +) + +type AdminServer struct { + pb.UnimplementedAdminServiceServer + configData []byte + manager *manager.Manager +} + +func NewAdminServer() *AdminServer { + return &AdminServer{} +} + +func (s *AdminServer) SetManager(m *manager.Manager) { + s.manager = m +} + +func (s *AdminServer) LoadConfig(ctx context.Context, req *pb.LoadConfigRequest) (*pb.LoadConfigResponse, error) { + s.configData = req.ConfigData + + if s.manager != nil { + s.manager.UpdateConfig(s.configData) + } + + return &pb.LoadConfigResponse{ + Success: true, + }, nil +} + +func (s *AdminServer) GetConfig() []byte { + return s.configData +} diff --git a/controller/internal/grpcserver/data_server.go b/controller/internal/grpcserver/data_server.go new file mode 100644 index 0000000..48af5c2 --- /dev/null +++ b/controller/internal/grpcserver/data_server.go @@ -0,0 +1,44 @@ +package grpcserver + +import ( + "context" + "log" + + "github.com/moevm/grpc_server/internal/manager" + pb "github.com/moevm/grpc_server/pkg/proto/communication" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +type DataServer struct { + pb.UnimplementedDataServiceServer + manager *manager.Manager +} + +func NewDataServer(mgr *manager.Manager) *DataServer { + return &DataServer{manager: mgr} +} + +func (s *DataServer) GetPolicy(ctx context.Context, req *pb.GetPolicyRequest) (*pb.WorkerPolicy, error) { + log.Printf("gRPC GetPolicy from worker %d (hash: %d)", req.WorkerId, req.PolicyHash) + policy := s.manager.GetWorkerPolicy(req.WorkerId) + if policy == nil { + return nil, status.Errorf(codes.NotFound, "no policy for worker %d", req.WorkerId) + } + return policy, nil +} + +func (s *DataServer) Classify(ctx context.Context, req *pb.ClassifyRequest) (*pb.ClassifyResponse, error) { + log.Printf("gRPC Classify from worker %d for domain: %s", req.WorkerId, req.Domain) + return &pb.ClassifyResponse{ + Categories: []string{"unknown"}, + TrustLevel: 50, + }, nil +} + +func (s *DataServer) SendStats(ctx context.Context, report *pb.StatsReport) (*emptypb.Empty, error) { + log.Printf("gRPC Stats from worker %d: blocked=%d allowed=%d", + report.WorkerId, report.TotalBlocked, report.TotalAllowed) + return &emptypb.Empty{}, nil +} \ No newline at end of file diff --git a/controller/internal/manager/BUILD b/controller/internal/manager/BUILD index fc18761..91f216a 100644 --- a/controller/internal/manager/BUILD +++ b/controller/internal/manager/BUILD @@ -2,14 +2,15 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "manager", - srcs = ["manager.go"], + srcs = ["manager.go", "policy_manager.go",], importpath = "github.com/moevm/grpc_server/internal/manager", visibility = ["//visibility:public"], deps = [ "//internal/conn", - "//pkg/converter", "//pkg/proto/communication:communication_go_proto", "@org_golang_google_protobuf//proto", "@com_github_joho_godotenv//:go_default_library", + "@com_github_pelletier_go_toml//:go_default_library", + "@org_golang_google_protobuf//types/known/structpb", ], ) diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index 92b635b..ccf45b3 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -23,7 +23,6 @@ import ( "github.com/moevm/grpc_server/internal/conn" "google.golang.org/protobuf/proto" - communication "github.com/moevm/grpc_server/pkg/proto/communication" ) @@ -38,7 +37,7 @@ const ( const ( workerMainSocketPath = "/run/controller/main.sock" - workerSocketPath = "/run/controller/" + workerSocketPath = "/run/controller/" ) type IManager interface { @@ -49,9 +48,10 @@ type IManager interface { type Manager struct { listener net.Listener - workers map[uint64]*Worker - workersMutex sync.Mutex - workerId uint64 + policyManager *PolicyManager + workers map[uint64]*Worker + workersMutex sync.Mutex + workerId uint64 tasks map[uint64]*Task tasksMutex sync.Mutex @@ -328,6 +328,7 @@ func (m *Manager) mainLoop() { } } + // errorHandler catches the errors from goroutines and logs them. // This function should always be run in a goroutine. func (m *Manager) errorHandler() { @@ -504,6 +505,33 @@ func removeContents(dir string) error { return nil } +func (m *Manager) HandleGetPolicy(workerID uint64, currentHash uint64) ([]byte, bool, error) { + log.Printf("Worker %d requested policy", workerID) + + policyProto := m.policyManager.GetWorkerPolicyProto(workerID) + + if currentHash == policyProto.PolicyHash { + return nil, false, nil + } + + policyBytes, err := proto.Marshal(policyProto) + if err != nil { + return nil, false, err + } + + return policyBytes, true, nil +} + +func (m *Manager) GetWorkerPolicy(workerID uint64) *communication.WorkerPolicy { + return m.policyManager.GetWorkerPolicyProto(workerID) +} + +func (m *Manager) UpdateConfig(configData []byte) { + if m.policyManager != nil { + m.policyManager.UpdateConfig(configData) + } +} + func NewManager() (*Manager, error) { if err := removeContents(workerSocketPath); err != nil { return nil, fmt.Errorf("failed to clean socket directory: %w", err) @@ -514,8 +542,10 @@ func NewManager() (*Manager, error) { return nil, fmt.Errorf("failed to create listener: %w", err) } + m := &Manager{ listener: listener, + policyManager: NewPolicyManager(), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), freeWorkers: make(chan uint64, 32), @@ -546,4 +576,4 @@ func (m *Manager) Shutdown() { close(m.taskSolutions) close(m.errorChan) }) -} +} \ No newline at end of file diff --git a/controller/internal/manager/policy_manager.go b/controller/internal/manager/policy_manager.go new file mode 100644 index 0000000..efbaf3e --- /dev/null +++ b/controller/internal/manager/policy_manager.go @@ -0,0 +1,137 @@ +package manager + +import ( + "crypto/sha256" + "encoding/binary" + "fmt" + "log" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + "github.com/pelletier/go-toml" + pb "github.com/moevm/grpc_server/pkg/proto/communication" +) + +type TOMLRules struct { + BlockCategories []string `toml:"block_categories"` + BlockByTrust map[string]int32 `toml:"block_by_trust"` + BlockDomains []string `toml:"block_domains"` + AllowDomains []string `toml:"allow_domains"` + MinTrustLevel int32 `toml:"min_trust_level"` + Extra map[string]interface{} `toml:",remain"` +} + +type TOMLConfig struct { + Global struct { + Rules TOMLRules `toml:"rules"` + } `toml:"global"` + Filters map[string]TOMLRules `toml:"filters"` +} + +type PolicyManager struct { + config *TOMLConfig + version uint64 +} + +func NewPolicyManager() *PolicyManager { + return &PolicyManager{ + version: 0, + config: &TOMLConfig{}, + } +} + +func computeHash(policy *pb.WorkerPolicy) uint64 { + data, _ := proto.Marshal(policy) + hash := sha256.Sum256(data) + return binary.BigEndian.Uint64(hash[:8]) +} + +func (pm *PolicyManager) GetWorkerPolicyProto(workerID uint64) *pb.WorkerPolicy { + if pm.config == nil { + return &pb.WorkerPolicy{} + } + + policy := &pb.WorkerPolicy{ + BlockCategories: make([]string, len(pm.config.Global.Rules.BlockCategories)), + BlockByTrust: make(map[string]int32, len(pm.config.Global.Rules.BlockByTrust)), + BlockDomains: make([]string, len(pm.config.Global.Rules.BlockDomains)), + AllowDomains: make([]string, len(pm.config.Global.Rules.AllowDomains)), + MinTrustLevel: pm.config.Global.Rules.MinTrustLevel, + } + copy(policy.BlockCategories, pm.config.Global.Rules.BlockCategories) + for k, v := range pm.config.Global.Rules.BlockByTrust { + policy.BlockByTrust[k] = v + } + copy(policy.BlockDomains, pm.config.Global.Rules.BlockDomains) + copy(policy.AllowDomains, pm.config.Global.Rules.AllowDomains) + + filterName := fmt.Sprintf("filter_%d", workerID) + if filter, ok := pm.config.Filters[filterName]; ok { + if len(filter.BlockCategories) > 0 { + existing := make(map[string]bool) + for _, cat := range policy.BlockCategories { + existing[cat] = true + } + for _, cat := range filter.BlockCategories { + if !existing[cat] { + policy.BlockCategories = append(policy.BlockCategories, cat) + } + } + } + + for k, v := range filter.BlockByTrust { + policy.BlockByTrust[k] = v + } + + if len(filter.BlockDomains) > 0 { + existing := make(map[string]bool) + for _, d := range policy.BlockDomains { + existing[d] = true + } + for _, d := range filter.BlockDomains { + if !existing[d] { + policy.BlockDomains = append(policy.BlockDomains, d) + } + } + } + + if len(filter.AllowDomains) > 0 { + existing := make(map[string]bool) + for _, d := range policy.AllowDomains { + existing[d] = true + } + for _, d := range filter.AllowDomains { + if !existing[d] { + policy.AllowDomains = append(policy.AllowDomains, d) + } + } + } + if filter.MinTrustLevel != 0 { + policy.MinTrustLevel = filter.MinTrustLevel + } + + if len(filter.Extra) > 0 { + if extraStruct, err := structpb.NewStruct(filter.Extra); err == nil { + policy.Extra = extraStruct + } + } else { + if len(pm.config.Global.Rules.Extra) > 0 { + if extraStruct, err := structpb.NewStruct(pm.config.Global.Rules.Extra); err == nil { + policy.Extra = extraStruct + } + } + } + } + policy.PolicyHash = computeHash(policy) + return policy +} + +func (pm *PolicyManager) UpdateConfig(configData []byte) { + var cfg TOMLConfig + if err := toml.Unmarshal(configData, &cfg); err != nil { + log.Printf("Failed to parse TOML in UpdateConfig: %v", err) + return + } + pm.config = &cfg + pm.version++ + log.Printf("Config updated to version %d", pm.version) +} diff --git a/controller/pkg/proto/admin_service/BUILD b/controller/pkg/proto/admin_service/BUILD new file mode 100644 index 0000000..8ee5046 --- /dev/null +++ b/controller/pkg/proto/admin_service/BUILD @@ -0,0 +1,17 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_go//proto:def.bzl", "go_proto_library") + +proto_library( + name = "admin_service_proto", + srcs = ["admin_service.proto"], + visibility = ["//visibility:public"], +) + + +go_proto_library( + name = "admin_service_go_proto", + compilers = ["@rules_go//proto:go_grpc"], + importpath = "github.com/moevm/grpc_server/pkg/proto/admin_service", + proto = ":admin_service_proto", + visibility = ["//visibility:public"], +) diff --git a/controller/pkg/proto/admin_service/admin_service.proto b/controller/pkg/proto/admin_service/admin_service.proto new file mode 100644 index 0000000..c587721 --- /dev/null +++ b/controller/pkg/proto/admin_service/admin_service.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package admin_service; + +option go_package = "github.com/moevm/grpc_server/pkg/proto/admin_service"; + +service AdminService { + rpc LoadConfig(LoadConfigRequest) returns (LoadConfigResponse) {} +} + +message LoadConfigRequest { + bytes config_data = 1; +} + +message LoadConfigResponse { + bool success = 1; +} diff --git a/controller/pkg/proto/communication/BUILD b/controller/pkg/proto/communication/BUILD index d6785d4..d967a96 100644 --- a/controller/pkg/proto/communication/BUILD +++ b/controller/pkg/proto/communication/BUILD @@ -5,6 +5,10 @@ proto_library( name = "communication_proto", srcs = ["communication.proto"], visibility = ["//visibility:public"], + deps = [ + "@protobuf//:struct_proto", + "@protobuf//:empty_proto", + ], ) go_proto_library( diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 1330534..185fc54 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -1,5 +1,13 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; +import "google/protobuf/struct.proto"; +import "google/protobuf/empty.proto"; + +service DataService { + rpc GetPolicy(GetPolicyRequest) returns (WorkerPolicy); + rpc Classify(ClassifyRequest) returns (ClassifyResponse); + rpc SendStats(StatsReport) returns (google.protobuf.Empty); +} enum PulseType { PULSE_INVALD = 0; @@ -57,3 +65,44 @@ message WorkerResponse { uint64 task_id = 2; uint64 extra_size = 3; } + +message GetPolicyRequest { + uint64 worker_id = 1; + uint64 policy_hash = 2; + uint64 config_version = 3; +} + +message WorkerPolicy { + repeated string block_categories = 1; + map block_by_trust = 2; + repeated string block_domains = 3; + repeated string allow_domains = 4; + int32 min_trust_level = 5; + uint64 config_version = 6; + uint64 policy_hash = 7; + google.protobuf.Struct extra = 8; +} + +message ClassifyRequest { + uint64 worker_id = 1; + string domain = 2; +} + +message ClassifyResponse { + repeated string categories = 1; + int32 trust_level = 2; +} + +message ResourceStats { + string domain = 1; + uint64 blocked = 2; + uint64 allowed = 3; +} + +message StatsReport { + uint64 worker_id = 1; + uint64 time = 2; + uint64 total_blocked = 3; + uint64 total_allowed = 4; + repeated ResourceStats resources = 5; +} diff --git a/controller/test/BUILD b/controller/test/BUILD new file mode 100644 index 0000000..bcffe54 --- /dev/null +++ b/controller/test/BUILD @@ -0,0 +1,11 @@ +load("@rules_go//go:def.bzl", "go_test") + +go_test( + name = "integration_test", + srcs = ["worker_test.go"], + deps = [ + "@com_github_stretchr_testify//assert", + ], + args = ["--test.v"], + tags = ["exclusive"], +) \ No newline at end of file diff --git a/controller/test/README.md b/controller/test/README.md new file mode 100644 index 0000000..159ba36 --- /dev/null +++ b/controller/test/README.md @@ -0,0 +1,18 @@ +# Запуск теста + +### 1. Сборка компонентов + +```bash +cd controller +bazel build //cmd/grpc_server:grpc_server + +cd ../worker +bazel build //:worker +``` + +### 2. Запуск интеграционного теста + +```bash +cd ../controller +./test/run.sh +``` \ No newline at end of file diff --git a/controller/test/run.sh b/controller/test/run.sh new file mode 100755 index 0000000..ff2d1d7 --- /dev/null +++ b/controller/test/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -e + +PROJECT_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" + +cd "$PROJECT_ROOT/controller" + +echo "Сборка теста" +bazel build //test:integration_test + +echo "Запуск" +./bazel-bin/test/integration_test_/integration_test diff --git a/controller/test/worker_test.go b/controller/test/worker_test.go new file mode 100644 index 0000000..481bbf7 --- /dev/null +++ b/controller/test/worker_test.go @@ -0,0 +1,102 @@ +package test + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func findProjectRoot() string { + dir, _ := os.Getwd() + for { + if _, err := os.Stat(filepath.Join(dir, "controller")); err != nil { + if _, err := os.Stat(filepath.Join(dir, "worker")); err != nil { + dir = filepath.Dir(dir) + continue + } + } + return dir + } +} + +func TestWorkerPolicyRequest(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_REQUEST_POLICY=true", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Worker 1 requests policy") + assert.Contains(t, string(output), "Policy received") +} + +func TestWorkerStatsReport(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_STATS=true", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Worker 1 send stats") + assert.Contains(t, string(output), "Policy received") +} + +func TestWorkerClassifyRequest(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_CLASSIFY_DOMAIN=facebook.com", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Domain 'facebook.com' classified as category") +} \ No newline at end of file diff --git a/worker/BUILD b/worker/BUILD index f53c957..047688c 100644 --- a/worker/BUILD +++ b/worker/BUILD @@ -1,28 +1,12 @@ -cc_binary( - name = "worker", - srcs = [ - "src/main.cpp", - "src/md_calculator.cpp", - "src/file.cpp", - "src/worker.cpp", - "src/metrics_collector.cpp", - "include/file.hpp", - "include/md_calculator.hpp", - "include/worker.hpp", - "include/metrics_collector.hpp", - ], - deps = ["@spdlog//:spdlog", - "@curl//:curl", - ":communication_cc_proto"], - includes = ["/usr/local/include"], - linkopts = [ - "-L/usr/local/openssl/lib", - "-lssl", - "-lcrypto", - "-L/usr/local/lib", - "-lprometheus-cpp-push", - "-lprometheus-cpp-core", - ], +load("@grpc//bazel:generate_cc.bzl", "generate_cc") + +proto_library( + name = "communication_proto", + srcs = ["communication.proto"], + deps = [ + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + ], ) cc_proto_library( @@ -30,7 +14,62 @@ cc_proto_library( deps = [":communication_proto"], ) -proto_library( - name = "communication_proto", - srcs = [ "communication.proto" ] +generate_cc( + name = "communication_cc_grpc_gen", + srcs = [":communication_proto"], + plugin = "@grpc//src/compiler:grpc_cpp_plugin", + well_known_protos = True, + generate_mocks = True, +) + +cc_library( + name = "communication_cc_grpc", + srcs = [":communication_cc_grpc_gen"], + hdrs = [":communication_cc_grpc_gen"], + deps = [ + ":communication_cc_proto", + "@grpc//:grpc++", + ], +) + +cc_library( + name = "worker_headers", + hdrs = [ + "include/worker.hpp", + "include/md_calculator.hpp", + "include/file.hpp", + "include/metrics_collector.hpp", + ], + srcs = [], + visibility = ["//visibility:public"], +) + +cc_binary( + name = "worker", + srcs = [ + "src/main.cpp", + "src/md_calculator.cpp", + "src/file.cpp", + "src/worker.cpp", + "src/metrics_collector.cpp", + ], + deps = [ + ":worker_headers", + ":communication_cc_proto", + ":communication_cc_grpc", + "@grpc//:grpc++", + "@spdlog//:spdlog", + "@curl//:curl", + ], + copts = [ + "-I$(GENDIR)/..", + ], + linkopts = [ + "-L/usr/local/openssl/lib", + "-lssl", + "-lcrypto", + "-L/usr/local/lib", + "-lprometheus-cpp-push", + "-lprometheus-cpp-core", + ], ) diff --git a/worker/MODULE.bazel b/worker/MODULE.bazel index dc3d2aa..6f3f40d 100644 --- a/worker/MODULE.bazel +++ b/worker/MODULE.bazel @@ -1,10 +1,14 @@ module( - name = "worker", + name = "worker", + version = "0.1.0", ) bazel_dep(name = "rules_cc", version = "0.1.1") bazel_dep(name = "platforms", version = "0.0.11") bazel_dep(name = "spdlog", version = "1.15.2") bazel_dep(name = "abseil-cpp", version = "20250512.1") -bazel_dep(name = "protobuf", version = "32.0-rc1") bazel_dep(name = "curl", version = "8.8.0.bcr.3") +bazel_dep(name = "bazel_skylib", version = "1.8.1") +bazel_dep(name = "rules_proto", version = "7.0.2") +bazel_dep(name = "protobuf", version = "31.1", repo_name = "com_google_protobuf") +bazel_dep(name = "grpc", version = "1.78.0") diff --git a/worker/communication.proto b/worker/communication.proto index 1330534..185fc54 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -1,5 +1,13 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; +import "google/protobuf/struct.proto"; +import "google/protobuf/empty.proto"; + +service DataService { + rpc GetPolicy(GetPolicyRequest) returns (WorkerPolicy); + rpc Classify(ClassifyRequest) returns (ClassifyResponse); + rpc SendStats(StatsReport) returns (google.protobuf.Empty); +} enum PulseType { PULSE_INVALD = 0; @@ -57,3 +65,44 @@ message WorkerResponse { uint64 task_id = 2; uint64 extra_size = 3; } + +message GetPolicyRequest { + uint64 worker_id = 1; + uint64 policy_hash = 2; + uint64 config_version = 3; +} + +message WorkerPolicy { + repeated string block_categories = 1; + map block_by_trust = 2; + repeated string block_domains = 3; + repeated string allow_domains = 4; + int32 min_trust_level = 5; + uint64 config_version = 6; + uint64 policy_hash = 7; + google.protobuf.Struct extra = 8; +} + +message ClassifyRequest { + uint64 worker_id = 1; + string domain = 2; +} + +message ClassifyResponse { + repeated string categories = 1; + int32 trust_level = 2; +} + +message ResourceStats { + string domain = 1; + uint64 blocked = 2; + uint64 allowed = 3; +} + +message StatsReport { + uint64 worker_id = 1; + uint64 time = 2; + uint64 total_blocked = 3; + uint64 total_allowed = 4; + repeated ResourceStats resources = 5; +} diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 2fede92..a727c9a 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -3,11 +3,18 @@ #include "communication.pb.h" +#include "communication.grpc.pb.h" #include +#include #include +#include + #define SOCKET_DIR "/run/controller/" #define MAIN_SOCKET_NAME "main.sock" +#define POLICY_SOCKET_NAME "policy.sock" +#define CLASSIFY_SOCKET_NAME "classify.sock" +#define STATS_SOCKET_NAME "stats.sock" #define EXPECTED_PULSE_TIME 60 #define MIN_PULSE_TIME 30 #define MAX_PULSE_TIME 45 @@ -35,12 +42,16 @@ class Worker { int listener_fd = -1; uint64_t worker_id = 0; uint64_t current_task_id = 0; + uint64_t current_policy_hash = 0; + uint64_t current_config_version = 0; std::chrono::time_point last_pulse_time; uint64_t pulse_interval = MIN_PULSE_TIME; std::string fetch_data; std::string extra_data; + std::unique_ptr stub_; + enum class InitResponse : uint64_t { OK = 1 }; WorkerState state; void LogStateChange(WorkerState new_state); @@ -98,7 +109,9 @@ class Worker { ~Worker(); inline uint64_t GetID() const { return worker_id; } - + void requestPolicyFromController(); + void classifyDomain(const std::string &domain); + void statsReport(); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index 1f0b31d..8143472 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -39,7 +39,38 @@ int main() { gateway_port); try { - HashWorker(gateway_address, gateway_port).MainLoop(); + HashWorker worker(gateway_address, gateway_port); + + bool test_mode = false; + + if (getenv("TEST_REQUEST_POLICY") != nullptr) { + test_mode = true; + spdlog::info("Test mode: requesting policy"); + std::this_thread::sleep_for(std::chrono::seconds(2)); + worker.requestPolicyFromController(); + } + + if (getenv("TEST_STATS") != nullptr) { + test_mode = true; + spdlog::info("Test mode: send stats"); + std::this_thread::sleep_for(std::chrono::seconds(2)); + worker.statsReport(); + } + + if (const char *domain = getenv("TEST_CLASSIFY_DOMAIN")) { + test_mode = true; + spdlog::info("Test mode: classifying domain '{}'", domain); + std::this_thread::sleep_for(std::chrono::seconds(1)); + worker.classifyDomain(domain); + } + + if (test_mode) { + spdlog::info("Test mode completed, exiting"); + return 0; + } + + worker.MainLoop(); + } catch (WorkerException &e) { spdlog::error(e.what()); return 1; diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index a1a1f75..e596ee1 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -1,8 +1,10 @@ #include "../include/worker.hpp" +#include "communication.grpc.pb.h" #include #include #include +#include #include #include #include @@ -125,9 +127,95 @@ void Worker::SendPulse(PulseType type) { } } +void Worker::requestPolicyFromController() { + try { + spdlog::info("Worker {} requests policy", worker_id); + + GetPolicyRequest req; + req.set_worker_id(worker_id); + req.set_policy_hash(current_policy_hash); + req.set_config_version(current_config_version); + + WorkerPolicy policy; + grpc::ClientContext context; + + auto status = stub_->GetPolicy(&context, req, &policy); + if (!status.ok()) { + throw WorkerException("GetPolicy failed: " + status.error_message()); + } + + current_policy_hash = policy.policy_hash(); + spdlog::info("Policy received"); + + } catch (const std::exception &e) { + SetState(WorkerState::ERROR); + throw WorkerException(std::string("requestPolicyFromController: ") + + e.what()); + } +} + +void Worker::classifyDomain(const std::string &domain) { + try { + spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); + + ClassifyRequest req; + req.set_worker_id(worker_id); + req.set_domain(domain); + + ClassifyResponse resp; + grpc::ClientContext context; + + auto status = stub_->Classify(&context, req, &resp); + if (!status.ok()) { + throw WorkerException("Classify failed: " + status.error_message()); + } + + std::string cat = + resp.categories_size() > 0 ? resp.categories(0) : "unknown"; + spdlog::info("Domain '{}' classified as category '{}' with trust level {}", + domain, cat, resp.trust_level()); + + } catch (const std::exception &e) { + SetState(WorkerState::ERROR); + throw WorkerException(std::string("classifyDomain: ") + e.what()); + } +} + +void Worker::statsReport() { + try { + spdlog::info("Worker {} send stats", worker_id); + + StatsReport report; + report.set_worker_id(worker_id); + report.set_time(time(nullptr)); + + grpc::ClientContext context; + google::protobuf::Empty response; + + auto status = stub_->SendStats(&context, report, &response); + if (!status.ok()) { + throw WorkerException("SendStats failed: " + status.error_message()); + } + + spdlog::info("Stats sent successfully"); + + } catch (const std::exception &e) { + spdlog::error("statsReport failed: {}", e.what()); + } +} + Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_REGISTER); + std::string controller_addr = "localhost:50051"; + if (const char *env_addr = getenv("CONTROLLER_GRPC_ADDR")) { + controller_addr = env_addr; + } + auto channel = + grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); + stub_ = DataService::NewStub(channel); + spdlog::info("gRPC channel created to {}", controller_addr); + socket_path = std::string(SOCKET_DIR) + std::to_string(worker_id) + ".sock"; unlink(socket_path.c_str()); @@ -157,6 +245,7 @@ Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_OK); SetState(WorkerState::FREE); + requestPolicyFromController(); } Worker::~Worker() {