From edbf3344be7a94258ba74e7aab19c2e3e95ae3c3 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Thu, 19 Feb 2026 20:42:12 +0000 Subject: [PATCH 1/8] Add admin client and policy manager implementation --- admin/README.md | 26 ++++ admin/admin.py | 42 ++++++ admin/config.toml | 22 +++ admin/requirements.txt | 3 + controller/MODULE.bazel | 1 + controller/cmd/grpc_server/BUILD | 1 + controller/cmd/grpc_server/main.go | 9 +- controller/go.mod | 1 + controller/go.sum | 2 + controller/internal/grpcserver/BUILD | 11 +- .../internal/grpcserver/admin_server.go | 42 ++++++ controller/internal/manager/BUILD | 3 +- controller/internal/manager/manager.go | 25 +++- controller/internal/manager/policy_manager.go | 125 ++++++++++++++++++ controller/pkg/proto/admin_service/BUILD | 17 +++ .../proto/admin_service/admin_service.proto | 17 +++ .../proto/communication/communication.proto | 25 ++++ worker/communication.proto | 25 ++++ worker/include/worker.hpp | 2 +- worker/src/worker.cpp | 5 + 20 files changed, 395 insertions(+), 9 deletions(-) create mode 100644 admin/README.md create mode 100644 admin/admin.py create mode 100644 admin/config.toml create mode 100644 admin/requirements.txt create mode 100644 controller/internal/grpcserver/admin_server.go create mode 100644 controller/internal/manager/policy_manager.go create mode 100644 controller/pkg/proto/admin_service/BUILD create mode 100644 controller/pkg/proto/admin_service/admin_service.proto diff --git a/admin/README.md b/admin/README.md new file mode 100644 index 0000000..635c7da --- /dev/null +++ b/admin/README.md @@ -0,0 +1,26 @@ +# Admin Client + +## Установка зависимостей +```bash +pip install -r requirements.txt +``` + +## Генерация proto-файлов +```bash +python -m grpc_tools.protoc \ + --python_out=. \ + --grpc_python_out=. \ + -I ../controller/pkg/proto/admin_service \ + ../controller/pkg/proto/admin_service/admin_service.proto +``` + +## Запуск контроллера (в отдельном терминале) +```bash +cd ../controller +bazel run //cmd/grpc_server:grpc_server +``` + +## Запуск клиента +```bash +python admin.py --file config.toml +``` \ No newline at end of file diff --git a/admin/admin.py b/admin/admin.py new file mode 100644 index 0000000..d917f26 --- /dev/null +++ b/admin/admin.py @@ -0,0 +1,42 @@ +import argparse +import grpc +from dotenv import load_dotenv +import os + +import admin_service_pb2 +import admin_service_pb2_grpc + +load_dotenv(os.path.join(os.path.dirname(__file__), "..", "controller", ".env")) + +class AdminClient: + def __init__(self): + self.host: str = os.environ["SERVER_HOST"] + self.port: str = os.environ["SERVER_PORT"] + self.channel: grpc.channel = grpc.insecure_channel(f"{self.host}:{self.port}") + self.stub: admin_service_pb2_grpc.AdminServiceStub = admin_service_pb2_grpc.AdminServiceStub(self.channel) + + def load_config(self, toml_file: str) -> admin_service_pb2.LoadConfigResponse: + with open(toml_file, "rb") as f: + content: bytes = f.read() + + request: admin_service_pb2.LoadConfigRequest = (admin_service_pb2.LoadConfigRequest(config_data=content)) + return self.stub.LoadConfig(request) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--file", default="config.toml") + args = parser.parse_args() + + try: + client = AdminClient() + response = client.load_config(args.file) + print("Config loaded") + except Exception as e: + print(f"Error loading config: {e}") + return 1 + + return 0 + +if __name__ == "__main__": + main() + diff --git a/admin/config.toml b/admin/config.toml new file mode 100644 index 0000000..3964b87 --- /dev/null +++ b/admin/config.toml @@ -0,0 +1,22 @@ +[global.rules] +block_categories = ["MALWARE", "SOCIAL"] +block_domains = ["youtube.com", "tiktok.com"] +allow_domains = ["github.com", "stackoverflow.com"] + +[global.rules.block_by_trust] +ENTERTAINMENT = 6 +NEWS = 4 + +[filters.filter_1] +block_categories = ["SOCIAL", "ENTERTAINMENT"] +block_domains = ["instagram.com"] +allow_domains = ["vk.com"] + +[filters.filter_1.block_by_trust] +SOCIAL = 8 +ENTERTAINMENT = 7 + +[filters.filter_2] +block_categories = ["MALWARE"] +allow_domains = ["github.com", "gitlab.com"] +min_trust_level = 3 \ No newline at end of file diff --git a/admin/requirements.txt b/admin/requirements.txt new file mode 100644 index 0000000..a4fc5b2 --- /dev/null +++ b/admin/requirements.txt @@ -0,0 +1,3 @@ +grpcio==1.60.0 +protobuf==4.25.1 +python-dotenv==1.0.0 \ No newline at end of file diff --git a/controller/MODULE.bazel b/controller/MODULE.bazel index 2efd7fc..06b9ce2 100644 --- a/controller/MODULE.bazel +++ b/controller/MODULE.bazel @@ -25,6 +25,7 @@ use_repo( go_deps, "com_github_joho_godotenv", "com_github_stretchr_testify", + "com_github_pelletier_go_toml", "org_golang_google_grpc", "org_golang_google_protobuf", "org_golang_x_text", diff --git a/controller/cmd/grpc_server/BUILD b/controller/cmd/grpc_server/BUILD index 2c14b95..7aae62a 100644 --- a/controller/cmd/grpc_server/BUILD +++ b/controller/cmd/grpc_server/BUILD @@ -15,6 +15,7 @@ go_library( "//internal/config", "//internal/grpcserver", "//internal/manager", + "//pkg/proto/admin_service:admin_service_go_proto", "//pkg/proto/file_service:file_service_go_proto", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//reflection", diff --git a/controller/cmd/grpc_server/main.go b/controller/cmd/grpc_server/main.go index 7b5f99f..f6653c6 100644 --- a/controller/cmd/grpc_server/main.go +++ b/controller/cmd/grpc_server/main.go @@ -10,16 +10,20 @@ import ( pb "github.com/moevm/grpc_server/pkg/proto/file_service" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + adminPb "github.com/moevm/grpc_server/pkg/proto/admin_service" ) func main() { cfg := config.Load() - - mgr, err := manager.NewManager() + adminServer := grpcserver.NewAdminServer() + configData, configVersion := adminServer.GetConfig() + mgr, err := manager.NewManager(configData, configVersion) if err != nil { log.Fatalf("manager.NewManager(): %v", err) } + adminServer.SetManager(mgr) + lis, err := net.Listen("tcp", net.JoinHostPort(cfg.Host, cfg.Port)) if err != nil { log.Fatalf("failed to listen: %v", err) @@ -31,6 +35,7 @@ func main() { } service := grpc.NewServer(serverOpts...) + adminPb.RegisterAdminServiceServer(service, adminServer) pb.RegisterFileServiceServer(service, grpcserver.NewServer(cfg.AllowedChars, mgr)) reflection.Register(service) diff --git a/controller/go.mod b/controller/go.mod index e2a0fff..306b484 100644 --- a/controller/go.mod +++ b/controller/go.mod @@ -11,6 +11,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pelletier/go-toml v1.9.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/controller/go.sum b/controller/go.sum index 3dd6f80..9f76e60 100644 --- a/controller/go.sum +++ b/controller/go.sum @@ -12,6 +12,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= +github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/controller/internal/grpcserver/BUILD b/controller/internal/grpcserver/BUILD index ffc12e3..78ba59f 100644 --- a/controller/internal/grpcserver/BUILD +++ b/controller/internal/grpcserver/BUILD @@ -2,21 +2,26 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "grpcserver", - srcs = ["server.go"], + srcs = [ + "server.go", + "admin_server.go", + ], importpath = "github.com/moevm/grpc_server/internal/grpcserver", visibility = ["//visibility:public"], deps = [ "//internal/manager", "//pkg/proto/file_service:file_service_go_proto", + "//pkg/proto/admin_service:admin_service_go_proto", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", - ], + "@org_golang_google_grpc//:go_default_library", + ] ) go_test( name = "grpcserver_test", srcs = ["server_test.go"], - embed = [ ":grpcserver" ], + embed = [":grpcserver"], deps = [ "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/controller/internal/grpcserver/admin_server.go b/controller/internal/grpcserver/admin_server.go new file mode 100644 index 0000000..88a8e82 --- /dev/null +++ b/controller/internal/grpcserver/admin_server.go @@ -0,0 +1,42 @@ +package grpcserver + +import ( + "context" + "github.com/moevm/grpc_server/internal/manager" + pb "github.com/moevm/grpc_server/pkg/proto/admin_service" +) + +type AdminServer struct { + pb.UnimplementedAdminServiceServer + configData []byte + version uint64 + manager *manager.Manager +} + +func NewAdminServer() *AdminServer { + return &AdminServer{ + version: 0, + } +} + +func (s *AdminServer) SetManager(m *manager.Manager) { + s.manager = m +} + +func (s *AdminServer) LoadConfig(ctx context.Context, req *pb.LoadConfigRequest) (*pb.LoadConfigResponse, error) { + s.configData = req.ConfigData + s.version++ + + if s.manager != nil { + s.manager.UpdateConfig(s.configData, s.version) + } + + return &pb.LoadConfigResponse{ + Success: true, + }, nil +} + + +func (s *AdminServer) GetConfig() ([]byte, uint64) { + return s.configData, s.version +} diff --git a/controller/internal/manager/BUILD b/controller/internal/manager/BUILD index fc18761..d2ce857 100644 --- a/controller/internal/manager/BUILD +++ b/controller/internal/manager/BUILD @@ -2,7 +2,7 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "manager", - srcs = ["manager.go"], + srcs = ["manager.go", "policy_manager.go",], importpath = "github.com/moevm/grpc_server/internal/manager", visibility = ["//visibility:public"], deps = [ @@ -11,5 +11,6 @@ go_library( "//pkg/proto/communication:communication_go_proto", "@org_golang_google_protobuf//proto", "@com_github_joho_godotenv//:go_default_library", + "@com_github_pelletier_go_toml//:go_default_library", ], ) diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index 92b635b..ff2386b 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -23,7 +23,6 @@ import ( "github.com/moevm/grpc_server/internal/conn" "google.golang.org/protobuf/proto" - communication "github.com/moevm/grpc_server/pkg/proto/communication" ) @@ -49,6 +48,7 @@ type IManager interface { type Manager struct { listener net.Listener + policyManager *PolicyManager workers map[uint64]*Worker workersMutex sync.Mutex workerId uint64 @@ -504,7 +504,27 @@ func removeContents(dir string) error { return nil } -func NewManager() (*Manager, error) { +func (m *Manager) HandleGetPolicy(workerID uint64) ([]byte, error) { + log.Printf("Worker %d requested policy", workerID) + + policyProto := m.policyManager.GetWorkerPolicyProto(workerID) + + policyBytes, err := proto.Marshal(policyProto) + if err != nil { + return nil, err + } + + log.Printf("Sending policy to worker %d", workerID) + return policyBytes, nil +} + +func (m *Manager) UpdateConfig(configData []byte, version uint64) { + if m.policyManager != nil { + m.policyManager.UpdateConfig(configData, version) + } +} + +func NewManager(configData []byte, version uint64) (*Manager, error) { if err := removeContents(workerSocketPath); err != nil { return nil, fmt.Errorf("failed to clean socket directory: %w", err) } @@ -516,6 +536,7 @@ func NewManager() (*Manager, error) { m := &Manager{ listener: listener, + policyManager: NewPolicyManager(configData, version), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), freeWorkers: make(chan uint64, 32), diff --git a/controller/internal/manager/policy_manager.go b/controller/internal/manager/policy_manager.go new file mode 100644 index 0000000..dcea20b --- /dev/null +++ b/controller/internal/manager/policy_manager.go @@ -0,0 +1,125 @@ +package manager + +import ( + "fmt" + "log" + + "github.com/pelletier/go-toml" + pb "github.com/moevm/grpc_server/pkg/proto/communication" +) + +type TOMLRules struct { + BlockCategories []string `toml:"block_categories"` + BlockByTrust map[string]int32 `toml:"block_by_trust"` + BlockDomains []string `toml:"block_domains"` + AllowDomains []string `toml:"allow_domains"` + MinTrustLevel int32 `toml:"min_trust_level"` +} + +type TOMLConfig struct { + Global struct { + Rules TOMLRules `toml:"rules"` + } `toml:"global"` + Filters map[string]TOMLRules `toml:"filters"` +} + +type PolicyManager struct { + config *TOMLConfig + version uint64 +} + +func NewPolicyManager(configData []byte, version uint64) *PolicyManager { + pm := &PolicyManager{ + version: version, + config: &TOMLConfig{}, + } + if len(configData) > 0 { + var cfg TOMLConfig + if err := toml.Unmarshal(configData, &cfg); err != nil { + log.Printf("Failed to parse TOML: %v", err) + } else { + pm.config = &cfg + } + } + return pm +} + +func (pm *PolicyManager) GetWorkerPolicyProto(workerID uint64) *pb.WorkerPolicy { + if pm.config == nil { + return &pb.WorkerPolicy{Version: pm.version} + } + + policy := &pb.WorkerPolicy{ + BlockCategories: make([]string, len(pm.config.Global.Rules.BlockCategories)), + BlockByTrust: make(map[string]int32, len(pm.config.Global.Rules.BlockByTrust)), + BlockDomains: make([]string, len(pm.config.Global.Rules.BlockDomains)), + AllowDomains: make([]string, len(pm.config.Global.Rules.AllowDomains)), + MinTrustLevel: pm.config.Global.Rules.MinTrustLevel, + Version: pm.version, + } + copy(policy.BlockCategories, pm.config.Global.Rules.BlockCategories) + for k, v := range pm.config.Global.Rules.BlockByTrust { + policy.BlockByTrust[k] = v + } + copy(policy.BlockDomains, pm.config.Global.Rules.BlockDomains) + copy(policy.AllowDomains, pm.config.Global.Rules.AllowDomains) + + filterName := fmt.Sprintf("filter_%d", workerID) + if filter, ok := pm.config.Filters[filterName]; ok { + if len(filter.BlockCategories) > 0 { + existing := make(map[string]bool) + for _, cat := range policy.BlockCategories { + existing[cat] = true + } + for _, cat := range filter.BlockCategories { + if !existing[cat] { + policy.BlockCategories = append(policy.BlockCategories, cat) + } + } + } + + for k, v := range filter.BlockByTrust { + policy.BlockByTrust[k] = v + } + + if len(filter.BlockDomains) > 0 { + existing := make(map[string]bool) + for _, d := range policy.BlockDomains { + existing[d] = true + } + for _, d := range filter.BlockDomains { + if !existing[d] { + policy.BlockDomains = append(policy.BlockDomains, d) + } + } + } + + if len(filter.AllowDomains) > 0 { + existing := make(map[string]bool) + for _, d := range policy.AllowDomains { + existing[d] = true + } + for _, d := range filter.AllowDomains { + if !existing[d] { + policy.AllowDomains = append(policy.AllowDomains, d) + } + } + } + if filter.MinTrustLevel != 0 { + policy.MinTrustLevel = filter.MinTrustLevel + } + } + + return policy +} + +func (pm *PolicyManager) UpdateConfig(configData []byte, version uint64) { + var cfg TOMLConfig + if err := toml.Unmarshal(configData, &cfg); err != nil { + log.Printf("Failed to parse TOML in UpdateConfig: %v", err) + return + } + pm.config = &cfg + pm.version = version + log.Printf("Config updated to version %d", version) +} diff --git a/controller/pkg/proto/admin_service/BUILD b/controller/pkg/proto/admin_service/BUILD new file mode 100644 index 0000000..8ee5046 --- /dev/null +++ b/controller/pkg/proto/admin_service/BUILD @@ -0,0 +1,17 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_go//proto:def.bzl", "go_proto_library") + +proto_library( + name = "admin_service_proto", + srcs = ["admin_service.proto"], + visibility = ["//visibility:public"], +) + + +go_proto_library( + name = "admin_service_go_proto", + compilers = ["@rules_go//proto:go_grpc"], + importpath = "github.com/moevm/grpc_server/pkg/proto/admin_service", + proto = ":admin_service_proto", + visibility = ["//visibility:public"], +) diff --git a/controller/pkg/proto/admin_service/admin_service.proto b/controller/pkg/proto/admin_service/admin_service.proto new file mode 100644 index 0000000..c587721 --- /dev/null +++ b/controller/pkg/proto/admin_service/admin_service.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package admin_service; + +option go_package = "github.com/moevm/grpc_server/pkg/proto/admin_service"; + +service AdminService { + rpc LoadConfig(LoadConfigRequest) returns (LoadConfigResponse) {} +} + +message LoadConfigRequest { + bytes config_data = 1; +} + +message LoadConfigResponse { + bool success = 1; +} diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 1330534..88521e9 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -35,6 +35,8 @@ enum ControlType { CTRL_FETCH = 2; CTRL_SET_TASK = 3; CTRL_GET_STATUS = 4; + CTRL_GET_POLICY = 5; + CTRL_CLASSIFY = 6; } message ControlMsg { @@ -57,3 +59,26 @@ message WorkerResponse { uint64 task_id = 2; uint64 extra_size = 3; } + +message GetPolicyRequest { + uint64 worker_id = 1; +} + +message WorkerPolicy { + repeated string block_categories = 1; + map block_by_trust = 2; + repeated string block_domains = 3; + repeated string allow_domains = 4; + int32 min_trust_level = 5; + uint64 version = 6; +} + +message ClassifyRequest { + uint64 worker_id = 1; + string domain = 2; +} + +message ClassifyResponse { + string category = 1; + int32 trust_level = 2; +} diff --git a/worker/communication.proto b/worker/communication.proto index 1330534..88521e9 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -35,6 +35,8 @@ enum ControlType { CTRL_FETCH = 2; CTRL_SET_TASK = 3; CTRL_GET_STATUS = 4; + CTRL_GET_POLICY = 5; + CTRL_CLASSIFY = 6; } message ControlMsg { @@ -57,3 +59,26 @@ message WorkerResponse { uint64 task_id = 2; uint64 extra_size = 3; } + +message GetPolicyRequest { + uint64 worker_id = 1; +} + +message WorkerPolicy { + repeated string block_categories = 1; + map block_by_trust = 2; + repeated string block_domains = 3; + repeated string allow_domains = 4; + int32 min_trust_level = 5; + uint64 version = 6; +} + +message ClassifyRequest { + uint64 worker_id = 1; + string domain = 2; +} + +message ClassifyResponse { + string category = 1; + int32 trust_level = 2; +} diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 2fede92..f8d99fd 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -98,7 +98,7 @@ class Worker { ~Worker(); inline uint64_t GetID() const { return worker_id; } - + void requestPolicyFromController(); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index a1a1f75..c47fe99 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -157,6 +157,7 @@ Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_OK); SetState(WorkerState::FREE); + requestPolicyFromController(); } Worker::~Worker() { @@ -214,6 +215,10 @@ void Worker::HandleSetTaskControlMessage(const ControlMsg &msg, std::thread(ProcessTask_Static, this, extra).detach(); } +void Worker::requestPolicyFromController() { + spdlog::info("Worker {} requesting policy from controller.", worker_id); +} + void Worker::HandleGetStatusControlMessage(WorkerResponse &resp) {} void Worker::HandleControlMessage(int client_fd) { From 4c6d85281a9886527c2452b5cc2e9b2c121457c2 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Sat, 21 Feb 2026 21:11:58 +0000 Subject: [PATCH 2/8] update --- admin/admin.py | 2 +- admin/requirements.txt | 1 + controller/internal/manager/manager.go | 112 ++++++++++++++++++ .../proto/communication/communication.proto | 4 +- worker/communication.proto | 4 +- worker/include/worker.hpp | 3 + worker/src/worker.cpp | 73 +++++++++++- 7 files changed, 188 insertions(+), 11 deletions(-) diff --git a/admin/admin.py b/admin/admin.py index d917f26..613afc3 100644 --- a/admin/admin.py +++ b/admin/admin.py @@ -29,7 +29,7 @@ def main(): try: client = AdminClient() - response = client.load_config(args.file) + client.load_config(args.file) print("Config loaded") except Exception as e: print(f"Error loading config: {e}") diff --git a/admin/requirements.txt b/admin/requirements.txt index a4fc5b2..3bf5f6e 100644 --- a/admin/requirements.txt +++ b/admin/requirements.txt @@ -1,3 +1,4 @@ grpcio==1.60.0 +grpcio-tools==1.60.0 protobuf==4.25.1 python-dotenv==1.0.0 \ No newline at end of file diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index ff2386b..d884977 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -37,6 +37,8 @@ const ( const ( workerMainSocketPath = "/run/controller/main.sock" + workerPolicySocketPath = "/run/controller/policy.sock" + workerClassifySocketPath = "/run/controller/classify.sock" workerSocketPath = "/run/controller/" ) @@ -47,6 +49,8 @@ type IManager interface { type Manager struct { listener net.Listener + policyListener net.Listener + classifyListener net.Listener policyManager *PolicyManager workers map[uint64]*Worker @@ -328,6 +332,100 @@ func (m *Manager) mainLoop() { } } + +func (m *Manager) policyLoop() { + log.Print("Listening on policy.sock") + for { + select { + case <-m.shutdown: + return + default: + netConn, err := m.policyListener.Accept() + if err != nil { + log.Printf("Policy accept error: %v", err) + continue + } + go m.handlePolicyConnection(conn.Unix{Conn: netConn}) + } + } +} + +func (m *Manager) classifyLoop() { + log.Print("Listening on classify.sock") + for { + select { + case <-m.shutdown: + return + default: + netConn, err := m.classifyListener.Accept() + if err != nil { + log.Printf("Classify accept error: %v", err) + continue + } + go m.handleClassifyConnection(conn.Unix{Conn: netConn}) + } + } +} + +func (m *Manager) handlePolicyConnection(conn conn.Unix) { + defer conn.Close() + + msgData, err := conn.ReadMessage() + if err != nil { + m.errorChan <- fmt.Errorf("policy read error: %w", err) + return + } + + var req communication.GetPolicyRequest + if err := proto.Unmarshal(msgData, &req); err != nil { + m.errorChan <- fmt.Errorf("policy unmarshal error: %w", err) + return + } + + log.Printf("Policy request received from worker %d", req.WorkerId) + + policyBytes, err := m.HandleGetPolicy(req.WorkerId) + if err != nil { + log.Printf("Error getting policy: %v", err) + return + } + + conn.WriteMessage(policyBytes) + log.Printf("Policy sent worker %d", req.WorkerId) +} + +func (m *Manager) handleClassifyConnection(conn conn.Unix) { + defer conn.Close() + + msgData, err := conn.ReadMessage() + if err != nil { + m.errorChan <- fmt.Errorf("classify read error: %w", err) + return + } + + var req communication.ClassifyRequest + if err := proto.Unmarshal(msgData, &req); err != nil { + m.errorChan <- fmt.Errorf("classify unmarshal error: %w", err) + return + } + + log.Printf("Classify request received from worker %d for domen '%s'", + req.WorkerId, req.Domen) + + resp := &communication.ClassifyResponse{ + Category: "unknown", + TrustLevel: 50, + } + + respData, err := proto.Marshal(resp) + if err != nil { + m.errorChan <- fmt.Errorf("marshal classify response error: %w", err) + return + } + + conn.WriteMessage(respData) + log.Printf("Classify response sent to worker %d", req.WorkerId) +} // errorHandler catches the errors from goroutines and logs them. // This function should always be run in a goroutine. func (m *Manager) errorHandler() { @@ -534,8 +632,20 @@ func NewManager(configData []byte, version uint64) (*Manager, error) { return nil, fmt.Errorf("failed to create listener: %w", err) } + policyListener, err := net.Listen("unix", workerPolicySocketPath) + if err != nil { + return nil, fmt.Errorf("failed to create policy listener: %w", err) + } + + classifyListener, err := net.Listen("unix", workerClassifySocketPath) + if err != nil { + return nil, fmt.Errorf("failed to create classify listener: %w", err) + } + m := &Manager{ listener: listener, + policyListener: policyListener, + classifyListener: classifyListener, policyManager: NewPolicyManager(configData, version), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), @@ -548,6 +658,8 @@ func NewManager(configData []byte, version uint64) (*Manager, error) { } go m.mainLoop() + go m.policyLoop() + go m.classifyLoop() go m.errorHandler() go m.dispatchTasks() go m.checkHealth() diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 88521e9..2f2add4 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -35,8 +35,6 @@ enum ControlType { CTRL_FETCH = 2; CTRL_SET_TASK = 3; CTRL_GET_STATUS = 4; - CTRL_GET_POLICY = 5; - CTRL_CLASSIFY = 6; } message ControlMsg { @@ -75,7 +73,7 @@ message WorkerPolicy { message ClassifyRequest { uint64 worker_id = 1; - string domain = 2; + string domen = 2; } message ClassifyResponse { diff --git a/worker/communication.proto b/worker/communication.proto index 88521e9..2f2add4 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -35,8 +35,6 @@ enum ControlType { CTRL_FETCH = 2; CTRL_SET_TASK = 3; CTRL_GET_STATUS = 4; - CTRL_GET_POLICY = 5; - CTRL_CLASSIFY = 6; } message ControlMsg { @@ -75,7 +73,7 @@ message WorkerPolicy { message ClassifyRequest { uint64 worker_id = 1; - string domain = 2; + string domen = 2; } message ClassifyResponse { diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index f8d99fd..645a577 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -8,6 +8,8 @@ #define SOCKET_DIR "/run/controller/" #define MAIN_SOCKET_NAME "main.sock" +#define POLICY_SOCKET_NAME "policy.sock" +#define CLASSIFY_SOCKET_NAME "classify.sock" #define EXPECTED_PULSE_TIME 60 #define MIN_PULSE_TIME 30 #define MAX_PULSE_TIME 45 @@ -99,6 +101,7 @@ class Worker { inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); + void classifyDomen(const std::string& domen); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index c47fe99..06b90e4 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -125,6 +125,75 @@ void Worker::SendPulse(PulseType type) { } } +void Worker::requestPolicyFromController() { + int main_fd = 0; + try { + spdlog::info("Worker {} requests policy", worker_id); + + GetPolicyRequest req; + req.set_worker_id(worker_id); + + main_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (main_fd < 0) + throw WorkerException(std::string("socket: ") + strerror(errno)); + + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, SOCKET_DIR POLICY_SOCKET_NAME, + sizeof(addr.sun_path) - 1); + + if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) + throw WorkerException(std::string("connect: ") + strerror(errno)); + WriteProtoMessage(main_fd, req); + WorkerPolicy policy; + ReadProtoMessage(main_fd, policy); + + spdlog::info("Policy received", policy.ShortDebugString()); + + close(main_fd); + } catch (const std::exception &e) { + close(main_fd); + SetState(WorkerState::ERROR); + spdlog::error("requestPolicyFromController failed: {}", e.what()); + throw WorkerException(std::string("requestPolicyFromController: ") + e.what()); + } +} + +void Worker::classifyDomen(const std::string& domen){ + int main_fd = 0; + try{ + spdlog::info("Worker {} classifying domen '{}'", worker_id, domen); + ClassifyRequest req; + req.set_worker_id(worker_id); + req.set_domen(domen); + + main_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (main_fd < 0) + throw WorkerException(std::string("socket: ") + strerror(errno)); + + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, SOCKET_DIR CLASSIFY_SOCKET_NAME, + sizeof(addr.sun_path) - 1); + + if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) + throw WorkerException(std::string("connect: ") + strerror(errno)); + + WriteProtoMessage(main_fd, req); + ClassifyResponse resp; + ReadProtoMessage(main_fd, resp); + + spdlog::info("Domen '{}' classified as category '{}' with trust level {}", domen, resp.category(), resp.trust_level()); + + close(main_fd); + + } catch (const std::exception &e) { + close(main_fd); + SetState(WorkerState::ERROR); + throw WorkerException(std::string("classifyDomain: ") + e.what()); + } +} + Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_REGISTER); @@ -215,10 +284,6 @@ void Worker::HandleSetTaskControlMessage(const ControlMsg &msg, std::thread(ProcessTask_Static, this, extra).detach(); } -void Worker::requestPolicyFromController() { - spdlog::info("Worker {} requesting policy from controller.", worker_id); -} - void Worker::HandleGetStatusControlMessage(WorkerResponse &resp) {} void Worker::HandleControlMessage(int client_fd) { From 94e0aeea83da947bf1499e02a07c12c9827dd876 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Thu, 26 Feb 2026 20:02:21 +0300 Subject: [PATCH 3/8] update 2 --- controller/cmd/grpc_server/main.go | 3 +- .../internal/grpcserver/admin_server.go | 13 +--- controller/internal/manager/BUILD | 2 +- controller/internal/manager/manager.go | 39 +++++++--- controller/internal/manager/policy_manager.go | 56 ++++++++------ controller/pkg/proto/communication/BUILD | 3 + .../proto/communication/communication.proto | 11 ++- controller/test/BUILD | 11 +++ controller/test/README.md | 18 +++++ controller/test/run.sh | 13 ++++ controller/test/worker_test.go | 76 +++++++++++++++++++ worker/BUILD | 5 +- worker/communication.proto | 11 ++- worker/include/worker.hpp | 4 +- worker/src/main.cpp | 26 ++++++- worker/src/worker.cpp | 12 +-- 16 files changed, 243 insertions(+), 60 deletions(-) create mode 100644 controller/test/BUILD create mode 100644 controller/test/README.md create mode 100755 controller/test/run.sh create mode 100644 controller/test/worker_test.go diff --git a/controller/cmd/grpc_server/main.go b/controller/cmd/grpc_server/main.go index f6653c6..96e05a2 100644 --- a/controller/cmd/grpc_server/main.go +++ b/controller/cmd/grpc_server/main.go @@ -16,8 +16,7 @@ import ( func main() { cfg := config.Load() adminServer := grpcserver.NewAdminServer() - configData, configVersion := adminServer.GetConfig() - mgr, err := manager.NewManager(configData, configVersion) + mgr, err := manager.NewManager() if err != nil { log.Fatalf("manager.NewManager(): %v", err) } diff --git a/controller/internal/grpcserver/admin_server.go b/controller/internal/grpcserver/admin_server.go index 88a8e82..386bb85 100644 --- a/controller/internal/grpcserver/admin_server.go +++ b/controller/internal/grpcserver/admin_server.go @@ -9,14 +9,11 @@ import ( type AdminServer struct { pb.UnimplementedAdminServiceServer configData []byte - version uint64 manager *manager.Manager } func NewAdminServer() *AdminServer { - return &AdminServer{ - version: 0, - } + return &AdminServer{} } func (s *AdminServer) SetManager(m *manager.Manager) { @@ -25,10 +22,9 @@ func (s *AdminServer) SetManager(m *manager.Manager) { func (s *AdminServer) LoadConfig(ctx context.Context, req *pb.LoadConfigRequest) (*pb.LoadConfigResponse, error) { s.configData = req.ConfigData - s.version++ if s.manager != nil { - s.manager.UpdateConfig(s.configData, s.version) + s.manager.UpdateConfig(s.configData) } return &pb.LoadConfigResponse{ @@ -36,7 +32,6 @@ func (s *AdminServer) LoadConfig(ctx context.Context, req *pb.LoadConfigRequest) }, nil } - -func (s *AdminServer) GetConfig() ([]byte, uint64) { - return s.configData, s.version +func (s *AdminServer) GetConfig() []byte { + return s.configData } diff --git a/controller/internal/manager/BUILD b/controller/internal/manager/BUILD index d2ce857..91f216a 100644 --- a/controller/internal/manager/BUILD +++ b/controller/internal/manager/BUILD @@ -7,10 +7,10 @@ go_library( visibility = ["//visibility:public"], deps = [ "//internal/conn", - "//pkg/converter", "//pkg/proto/communication:communication_go_proto", "@org_golang_google_protobuf//proto", "@com_github_joho_godotenv//:go_default_library", "@com_github_pelletier_go_toml//:go_default_library", + "@org_golang_google_protobuf//types/known/structpb", ], ) diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index d884977..28c0ac2 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -384,14 +384,24 @@ func (m *Manager) handlePolicyConnection(conn conn.Unix) { log.Printf("Policy request received from worker %d", req.WorkerId) - policyBytes, err := m.HandleGetPolicy(req.WorkerId) + policyBytes, needUpdate, err := m.HandleGetPolicy(req.WorkerId, req.PolicyHash) if err != nil { log.Printf("Error getting policy: %v", err) return } + if !needUpdate { + log.Printf("Worker %d already has latest policy", req.WorkerId) + resp := &communication.WorkerResponse{ + Error: communication.WorkerError_WORKER_ERR_OK, + } + respBytes, _ := proto.Marshal(resp) + conn.WriteMessage(respBytes) + return + } + conn.WriteMessage(policyBytes) - log.Printf("Policy sent worker %d", req.WorkerId) + log.Printf("Policy sent to worker %d", req.WorkerId) } func (m *Manager) handleClassifyConnection(conn conn.Unix) { @@ -410,10 +420,10 @@ func (m *Manager) handleClassifyConnection(conn conn.Unix) { } log.Printf("Classify request received from worker %d for domen '%s'", - req.WorkerId, req.Domen) + req.WorkerId, req.Domain) resp := &communication.ClassifyResponse{ - Category: "unknown", + Categories: []string{"unknown"}, TrustLevel: 50, } @@ -602,27 +612,30 @@ func removeContents(dir string) error { return nil } -func (m *Manager) HandleGetPolicy(workerID uint64) ([]byte, error) { +func (m *Manager) HandleGetPolicy(workerID uint64, currentHash uint64) ([]byte, bool, error) { log.Printf("Worker %d requested policy", workerID) policyProto := m.policyManager.GetWorkerPolicyProto(workerID) + if currentHash == policyProto.PolicyHash { + return nil, false, nil + } + policyBytes, err := proto.Marshal(policyProto) if err != nil { - return nil, err + return nil, false, err } - log.Printf("Sending policy to worker %d", workerID) - return policyBytes, nil + return policyBytes, true, nil } -func (m *Manager) UpdateConfig(configData []byte, version uint64) { +func (m *Manager) UpdateConfig(configData []byte) { if m.policyManager != nil { - m.policyManager.UpdateConfig(configData, version) + m.policyManager.UpdateConfig(configData) } } -func NewManager(configData []byte, version uint64) (*Manager, error) { +func NewManager() (*Manager, error) { if err := removeContents(workerSocketPath); err != nil { return nil, fmt.Errorf("failed to clean socket directory: %w", err) } @@ -646,7 +659,7 @@ func NewManager(configData []byte, version uint64) (*Manager, error) { listener: listener, policyListener: policyListener, classifyListener: classifyListener, - policyManager: NewPolicyManager(configData, version), + policyManager: NewPolicyManager(), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), freeWorkers: make(chan uint64, 32), @@ -672,6 +685,8 @@ func (m *Manager) Shutdown() { m.shutdownOnce.Do(func() { close(m.shutdown) m.listener.Close() + m.policyListener.Close() + m.classifyListener.Close() close(m.freeWorkers) close(m.fetchWorkers) diff --git a/controller/internal/manager/policy_manager.go b/controller/internal/manager/policy_manager.go index dcea20b..efbaf3e 100644 --- a/controller/internal/manager/policy_manager.go +++ b/controller/internal/manager/policy_manager.go @@ -1,9 +1,12 @@ package manager import ( + "crypto/sha256" + "encoding/binary" "fmt" "log" - + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" "github.com/pelletier/go-toml" pb "github.com/moevm/grpc_server/pkg/proto/communication" ) @@ -13,7 +16,8 @@ type TOMLRules struct { BlockByTrust map[string]int32 `toml:"block_by_trust"` BlockDomains []string `toml:"block_domains"` AllowDomains []string `toml:"allow_domains"` - MinTrustLevel int32 `toml:"min_trust_level"` + MinTrustLevel int32 `toml:"min_trust_level"` + Extra map[string]interface{} `toml:",remain"` } type TOMLConfig struct { @@ -28,25 +32,22 @@ type PolicyManager struct { version uint64 } -func NewPolicyManager(configData []byte, version uint64) *PolicyManager { - pm := &PolicyManager{ - version: version, - config: &TOMLConfig{}, - } - if len(configData) > 0 { - var cfg TOMLConfig - if err := toml.Unmarshal(configData, &cfg); err != nil { - log.Printf("Failed to parse TOML: %v", err) - } else { - pm.config = &cfg - } - } - return pm +func NewPolicyManager() *PolicyManager { + return &PolicyManager{ + version: 0, + config: &TOMLConfig{}, + } +} + +func computeHash(policy *pb.WorkerPolicy) uint64 { + data, _ := proto.Marshal(policy) + hash := sha256.Sum256(data) + return binary.BigEndian.Uint64(hash[:8]) } func (pm *PolicyManager) GetWorkerPolicyProto(workerID uint64) *pb.WorkerPolicy { if pm.config == nil { - return &pb.WorkerPolicy{Version: pm.version} + return &pb.WorkerPolicy{} } policy := &pb.WorkerPolicy{ @@ -55,7 +56,6 @@ func (pm *PolicyManager) GetWorkerPolicyProto(workerID uint64) *pb.WorkerPolicy BlockDomains: make([]string, len(pm.config.Global.Rules.BlockDomains)), AllowDomains: make([]string, len(pm.config.Global.Rules.AllowDomains)), MinTrustLevel: pm.config.Global.Rules.MinTrustLevel, - Version: pm.version, } copy(policy.BlockCategories, pm.config.Global.Rules.BlockCategories) for k, v := range pm.config.Global.Rules.BlockByTrust { @@ -108,18 +108,30 @@ func (pm *PolicyManager) GetWorkerPolicyProto(workerID uint64) *pb.WorkerPolicy if filter.MinTrustLevel != 0 { policy.MinTrustLevel = filter.MinTrustLevel } - } + if len(filter.Extra) > 0 { + if extraStruct, err := structpb.NewStruct(filter.Extra); err == nil { + policy.Extra = extraStruct + } + } else { + if len(pm.config.Global.Rules.Extra) > 0 { + if extraStruct, err := structpb.NewStruct(pm.config.Global.Rules.Extra); err == nil { + policy.Extra = extraStruct + } + } + } + } + policy.PolicyHash = computeHash(policy) return policy } -func (pm *PolicyManager) UpdateConfig(configData []byte, version uint64) { +func (pm *PolicyManager) UpdateConfig(configData []byte) { var cfg TOMLConfig if err := toml.Unmarshal(configData, &cfg); err != nil { log.Printf("Failed to parse TOML in UpdateConfig: %v", err) return } pm.config = &cfg - pm.version = version - log.Printf("Config updated to version %d", version) + pm.version++ + log.Printf("Config updated to version %d", pm.version) } diff --git a/controller/pkg/proto/communication/BUILD b/controller/pkg/proto/communication/BUILD index d6785d4..5fc229e 100644 --- a/controller/pkg/proto/communication/BUILD +++ b/controller/pkg/proto/communication/BUILD @@ -5,6 +5,9 @@ proto_library( name = "communication_proto", srcs = ["communication.proto"], visibility = ["//visibility:public"], + deps = [ + "@protobuf//:struct_proto", + ], ) go_proto_library( diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 2f2add4..03e953c 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -1,5 +1,6 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; +import "google/protobuf/struct.proto"; enum PulseType { PULSE_INVALD = 0; @@ -60,6 +61,8 @@ message WorkerResponse { message GetPolicyRequest { uint64 worker_id = 1; + uint64 policy_hash = 2; + uint64 config_version = 3; } message WorkerPolicy { @@ -68,15 +71,17 @@ message WorkerPolicy { repeated string block_domains = 3; repeated string allow_domains = 4; int32 min_trust_level = 5; - uint64 version = 6; + uint64 config_version = 6; + uint64 policy_hash = 7; + google.protobuf.Struct extra = 8; } message ClassifyRequest { uint64 worker_id = 1; - string domen = 2; + string domain = 2; } message ClassifyResponse { - string category = 1; + repeated string categories = 1; int32 trust_level = 2; } diff --git a/controller/test/BUILD b/controller/test/BUILD new file mode 100644 index 0000000..bcffe54 --- /dev/null +++ b/controller/test/BUILD @@ -0,0 +1,11 @@ +load("@rules_go//go:def.bzl", "go_test") + +go_test( + name = "integration_test", + srcs = ["worker_test.go"], + deps = [ + "@com_github_stretchr_testify//assert", + ], + args = ["--test.v"], + tags = ["exclusive"], +) \ No newline at end of file diff --git a/controller/test/README.md b/controller/test/README.md new file mode 100644 index 0000000..159ba36 --- /dev/null +++ b/controller/test/README.md @@ -0,0 +1,18 @@ +# Запуск теста + +### 1. Сборка компонентов + +```bash +cd controller +bazel build //cmd/grpc_server:grpc_server + +cd ../worker +bazel build //:worker +``` + +### 2. Запуск интеграционного теста + +```bash +cd ../controller +./test/run.sh +``` \ No newline at end of file diff --git a/controller/test/run.sh b/controller/test/run.sh new file mode 100755 index 0000000..ff2d1d7 --- /dev/null +++ b/controller/test/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -e + +PROJECT_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" + +cd "$PROJECT_ROOT/controller" + +echo "Сборка теста" +bazel build //test:integration_test + +echo "Запуск" +./bazel-bin/test/integration_test_/integration_test diff --git a/controller/test/worker_test.go b/controller/test/worker_test.go new file mode 100644 index 0000000..cb93491 --- /dev/null +++ b/controller/test/worker_test.go @@ -0,0 +1,76 @@ +package test + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func findProjectRoot() string { + dir, _ := os.Getwd() + for { + if _, err := os.Stat(filepath.Join(dir, "controller")); err != nil { + if _, err := os.Stat(filepath.Join(dir, "worker")); err != nil { + dir = filepath.Dir(dir) + continue + } + } + return dir + } +} + +func TestWorkerPolicyRequest(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_REQUEST_POLICY=true", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Worker 1 requests policy") + assert.Contains(t, string(output), "Policy received") +} + +func TestWorkerClassifyRequest(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_CLASSIFY_DOMAIN=facebook.com", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Domain 'facebook.com' classified as category") +} \ No newline at end of file diff --git a/worker/BUILD b/worker/BUILD index f53c957..1fe36cd 100644 --- a/worker/BUILD +++ b/worker/BUILD @@ -32,5 +32,8 @@ cc_proto_library( proto_library( name = "communication_proto", - srcs = [ "communication.proto" ] + srcs = [ "communication.proto" ], + deps = [ + "@protobuf//:struct_proto", + ], ) diff --git a/worker/communication.proto b/worker/communication.proto index 2f2add4..03e953c 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -1,5 +1,6 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; +import "google/protobuf/struct.proto"; enum PulseType { PULSE_INVALD = 0; @@ -60,6 +61,8 @@ message WorkerResponse { message GetPolicyRequest { uint64 worker_id = 1; + uint64 policy_hash = 2; + uint64 config_version = 3; } message WorkerPolicy { @@ -68,15 +71,17 @@ message WorkerPolicy { repeated string block_domains = 3; repeated string allow_domains = 4; int32 min_trust_level = 5; - uint64 version = 6; + uint64 config_version = 6; + uint64 policy_hash = 7; + google.protobuf.Struct extra = 8; } message ClassifyRequest { uint64 worker_id = 1; - string domen = 2; + string domain = 2; } message ClassifyResponse { - string category = 1; + repeated string categories = 1; int32 trust_level = 2; } diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 645a577..90254fc 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -37,6 +37,8 @@ class Worker { int listener_fd = -1; uint64_t worker_id = 0; uint64_t current_task_id = 0; + uint64_t current_policy_hash = 0; + uint64_t current_config_version = 0; std::chrono::time_point last_pulse_time; uint64_t pulse_interval = MIN_PULSE_TIME; @@ -101,7 +103,7 @@ class Worker { inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); - void classifyDomen(const std::string& domen); + void classifyDomain(const std::string& domain); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index 1f0b31d..714b658 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -39,7 +39,31 @@ int main() { gateway_port); try { - HashWorker(gateway_address, gateway_port).MainLoop(); + HashWorker worker(gateway_address, gateway_port); + + bool test_mode = false; + + if (getenv("TEST_REQUEST_POLICY") != nullptr) { + test_mode = true; + spdlog::info("Test mode: requesting policy"); + std::this_thread::sleep_for(std::chrono::seconds(2)); + worker.requestPolicyFromController(); + } + + if (const char* domain = getenv("TEST_CLASSIFY_DOMAIN")) { + test_mode = true; + spdlog::info("Test mode: classifying domain '{}'", domain); + std::this_thread::sleep_for(std::chrono::seconds(1)); + worker.classifyDomain(domain); + } + + if (test_mode) { + spdlog::info("Test mode completed, exiting"); + return 0; + } + + worker.MainLoop(); + } catch (WorkerException &e) { spdlog::error(e.what()); return 1; diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index 06b90e4..4860d01 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -132,6 +132,8 @@ void Worker::requestPolicyFromController() { GetPolicyRequest req; req.set_worker_id(worker_id); + req.set_policy_hash(current_policy_hash); + req.set_config_version(current_config_version); main_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (main_fd < 0) @@ -147,7 +149,7 @@ void Worker::requestPolicyFromController() { WriteProtoMessage(main_fd, req); WorkerPolicy policy; ReadProtoMessage(main_fd, policy); - + current_policy_hash = policy.policy_hash(); spdlog::info("Policy received", policy.ShortDebugString()); close(main_fd); @@ -159,13 +161,13 @@ void Worker::requestPolicyFromController() { } } -void Worker::classifyDomen(const std::string& domen){ +void Worker::classifyDomain(const std::string& domain) { int main_fd = 0; try{ - spdlog::info("Worker {} classifying domen '{}'", worker_id, domen); + spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); ClassifyRequest req; req.set_worker_id(worker_id); - req.set_domen(domen); + req.set_domain(domain); main_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (main_fd < 0) @@ -183,7 +185,7 @@ void Worker::classifyDomen(const std::string& domen){ ClassifyResponse resp; ReadProtoMessage(main_fd, resp); - spdlog::info("Domen '{}' classified as category '{}' with trust level {}", domen, resp.category(), resp.trust_level()); + spdlog::info("Domain '{}' classified as category '{}' with trust level {}", domain, resp.categories(0), resp.trust_level()); close(main_fd); From dada1dab8101feb38d6bfe9652bbfcd9d5fae2d2 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Thu, 26 Feb 2026 20:12:09 +0300 Subject: [PATCH 4/8] update 3 --- worker/include/worker.hpp | 2 +- worker/src/main.cpp | 22 +++++++++++----------- worker/src/worker.cpp | 38 ++++++++++++++++++++------------------ 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 90254fc..c968072 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -103,7 +103,7 @@ class Worker { inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); - void classifyDomain(const std::string& domain); + void classifyDomain(const std::string &domain); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index 714b658..9767793 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -44,22 +44,22 @@ int main() { bool test_mode = false; if (getenv("TEST_REQUEST_POLICY") != nullptr) { - test_mode = true; - spdlog::info("Test mode: requesting policy"); - std::this_thread::sleep_for(std::chrono::seconds(2)); - worker.requestPolicyFromController(); + test_mode = true; + spdlog::info("Test mode: requesting policy"); + std::this_thread::sleep_for(std::chrono::seconds(2)); + worker.requestPolicyFromController(); } - if (const char* domain = getenv("TEST_CLASSIFY_DOMAIN")) { - test_mode = true; - spdlog::info("Test mode: classifying domain '{}'", domain); - std::this_thread::sleep_for(std::chrono::seconds(1)); - worker.classifyDomain(domain); + if (const char *domain = getenv("TEST_CLASSIFY_DOMAIN")) { + test_mode = true; + spdlog::info("Test mode: classifying domain '{}'", domain); + std::this_thread::sleep_for(std::chrono::seconds(1)); + worker.classifyDomain(domain); } if (test_mode) { - spdlog::info("Test mode completed, exiting"); - return 0; + spdlog::info("Test mode completed, exiting"); + return 0; } worker.MainLoop(); diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index 4860d01..f90565d 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -129,15 +129,15 @@ void Worker::requestPolicyFromController() { int main_fd = 0; try { spdlog::info("Worker {} requests policy", worker_id); - + GetPolicyRequest req; req.set_worker_id(worker_id); req.set_policy_hash(current_policy_hash); req.set_config_version(current_config_version); - + main_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (main_fd < 0) - throw WorkerException(std::string("socket: ") + strerror(errno)); + throw WorkerException(std::string("socket: ") + strerror(errno)); sockaddr_un addr{}; addr.sun_family = AF_UNIX; @@ -145,7 +145,7 @@ void Worker::requestPolicyFromController() { sizeof(addr.sun_path) - 1); if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) - throw WorkerException(std::string("connect: ") + strerror(errno)); + throw WorkerException(std::string("connect: ") + strerror(errno)); WriteProtoMessage(main_fd, req); WorkerPolicy policy; ReadProtoMessage(main_fd, policy); @@ -154,41 +154,43 @@ void Worker::requestPolicyFromController() { close(main_fd); } catch (const std::exception &e) { - close(main_fd); - SetState(WorkerState::ERROR); - spdlog::error("requestPolicyFromController failed: {}", e.what()); - throw WorkerException(std::string("requestPolicyFromController: ") + e.what()); + close(main_fd); + SetState(WorkerState::ERROR); + spdlog::error("requestPolicyFromController failed: {}", e.what()); + throw WorkerException(std::string("requestPolicyFromController: ") + + e.what()); } } -void Worker::classifyDomain(const std::string& domain) { +void Worker::classifyDomain(const std::string &domain) { int main_fd = 0; - try{ + try { spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); ClassifyRequest req; req.set_worker_id(worker_id); req.set_domain(domain); main_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (main_fd < 0) - throw WorkerException(std::string("socket: ") + strerror(errno)); + if (main_fd < 0) + throw WorkerException(std::string("socket: ") + strerror(errno)); sockaddr_un addr{}; addr.sun_family = AF_UNIX; strncpy(addr.sun_path, SOCKET_DIR CLASSIFY_SOCKET_NAME, sizeof(addr.sun_path) - 1); - + if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) throw WorkerException(std::string("connect: ") + strerror(errno)); - + WriteProtoMessage(main_fd, req); ClassifyResponse resp; ReadProtoMessage(main_fd, resp); - - spdlog::info("Domain '{}' classified as category '{}' with trust level {}", domain, resp.categories(0), resp.trust_level()); - + + spdlog::info("Domain '{}' classified as category '{}' with trust level {}", + domain, resp.categories(0), resp.trust_level()); + close(main_fd); - + } catch (const std::exception &e) { close(main_fd); SetState(WorkerState::ERROR); From cda751c0b9a98755fb7e8f6384b97b304ab23c94 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Wed, 4 Mar 2026 17:44:47 +0300 Subject: [PATCH 5/8] update 4 --- controller/internal/manager/manager.go | 48 +++++++++++++++++++ .../proto/communication/communication.proto | 14 ++++++ controller/test/worker_test.go | 26 ++++++++++ worker/communication.proto | 14 ++++++ worker/include/worker.hpp | 2 + worker/src/main.cpp | 7 +++ worker/src/worker.cpp | 29 +++++++++++ 7 files changed, 140 insertions(+) diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index 28c0ac2..ec490db 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -39,6 +39,7 @@ const ( workerMainSocketPath = "/run/controller/main.sock" workerPolicySocketPath = "/run/controller/policy.sock" workerClassifySocketPath = "/run/controller/classify.sock" + workerStatsSocketPath = "/run/controller/stats.sock" workerSocketPath = "/run/controller/" ) @@ -51,6 +52,7 @@ type Manager struct { listener net.Listener policyListener net.Listener classifyListener net.Listener + statsListener net.Listener policyManager *PolicyManager workers map[uint64]*Worker @@ -367,6 +369,24 @@ func (m *Manager) classifyLoop() { } } +func (m *Manager) statsLoop() { + log.Print("Listening on stats.sock") + for { + select { + case <-m.shutdown: + return + default: + netConn, err := m.statsListener.Accept() + if err != nil { + log.Printf("Stats accept error: %v", err) + continue + } + go m.handleStatsConnection(conn.Unix{Conn: netConn}) + } + } +} + + func (m *Manager) handlePolicyConnection(conn conn.Unix) { defer conn.Close() @@ -436,6 +456,26 @@ func (m *Manager) handleClassifyConnection(conn conn.Unix) { conn.WriteMessage(respData) log.Printf("Classify response sent to worker %d", req.WorkerId) } + +func (m *Manager) handleStatsConnection(conn conn.Unix) { + defer conn.Close() + + msgData, err := conn.ReadMessage() + if err != nil { + m.errorChan <- fmt.Errorf("stats read error: %w", err) + return + } + + var req communication.StatsReport + if err := proto.Unmarshal(msgData, &req); err != nil { + m.errorChan <- fmt.Errorf("stats unmarshal error: %w", err) + return + } + + log.Printf("Stats worker %d", req.WorkerId) + +} + // errorHandler catches the errors from goroutines and logs them. // This function should always be run in a goroutine. func (m *Manager) errorHandler() { @@ -655,10 +695,16 @@ func NewManager() (*Manager, error) { return nil, fmt.Errorf("failed to create classify listener: %w", err) } + statsListener, err := net.Listen("unix", workerStatsSocketPath) + if err != nil { + return nil, fmt.Errorf("failed to create stats listener: %w", err) + } + m := &Manager{ listener: listener, policyListener: policyListener, classifyListener: classifyListener, + statsListener: statsListener, policyManager: NewPolicyManager(), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), @@ -673,6 +719,7 @@ func NewManager() (*Manager, error) { go m.mainLoop() go m.policyLoop() go m.classifyLoop() + go m.statsLoop() go m.errorHandler() go m.dispatchTasks() go m.checkHealth() @@ -687,6 +734,7 @@ func (m *Manager) Shutdown() { m.listener.Close() m.policyListener.Close() m.classifyListener.Close() + m.statsListener.Close() close(m.freeWorkers) close(m.fetchWorkers) diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 03e953c..7477f65 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -85,3 +85,17 @@ message ClassifyResponse { repeated string categories = 1; int32 trust_level = 2; } + +message ResourceStats { + string domain = 1; + uint64 blosked = 2; + uint64 allowed = 3; +} + +message StatsReport { + uint64 worker_id = 1; + uint64 time = 2; + uint64 total_blocked = 3; + uint64 total_allowed = 4; + repeated ResourceStats resources = 5; +} diff --git a/controller/test/worker_test.go b/controller/test/worker_test.go index cb93491..481bbf7 100644 --- a/controller/test/worker_test.go +++ b/controller/test/worker_test.go @@ -50,6 +50,32 @@ func TestWorkerPolicyRequest(t *testing.T) { assert.Contains(t, string(output), "Policy received") } +func TestWorkerStatsReport(t *testing.T) { + root := findProjectRoot() + + ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") + workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") + + ctrl := exec.Command(ctrlBin) + ctrl.Start() + defer ctrl.Process.Kill() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker := exec.CommandContext(ctx, workerBin) + worker.Env = []string{ + "METRICS_GATEWAY_ADDRESS=localhost", + "METRICS_GATEWAY_PORT=9091", + "TEST_STATS=true", + } + + output, err := worker.CombinedOutput() + assert.NoError(t, err, "Worker failed") + assert.Contains(t, string(output), "Worker 1 send stats") + assert.Contains(t, string(output), "Policy received") +} + func TestWorkerClassifyRequest(t *testing.T) { root := findProjectRoot() diff --git a/worker/communication.proto b/worker/communication.proto index 03e953c..7477f65 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -85,3 +85,17 @@ message ClassifyResponse { repeated string categories = 1; int32 trust_level = 2; } + +message ResourceStats { + string domain = 1; + uint64 blosked = 2; + uint64 allowed = 3; +} + +message StatsReport { + uint64 worker_id = 1; + uint64 time = 2; + uint64 total_blocked = 3; + uint64 total_allowed = 4; + repeated ResourceStats resources = 5; +} diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index c968072..06a8e5c 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -10,6 +10,7 @@ #define MAIN_SOCKET_NAME "main.sock" #define POLICY_SOCKET_NAME "policy.sock" #define CLASSIFY_SOCKET_NAME "classify.sock" +#define STATS_SOCKET_NAME "stats.sock" #define EXPECTED_PULSE_TIME 60 #define MIN_PULSE_TIME 30 #define MAX_PULSE_TIME 45 @@ -104,6 +105,7 @@ class Worker { inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); void classifyDomain(const std::string &domain); + void statsReport(); WorkerState GetState() const { return state; } virtual void ProcessTask(const std::vector &data) = 0; void MainLoop(); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index 9767793..8143472 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -50,6 +50,13 @@ int main() { worker.requestPolicyFromController(); } + if (getenv("TEST_STATS") != nullptr) { + test_mode = true; + spdlog::info("Test mode: send stats"); + std::this_thread::sleep_for(std::chrono::seconds(2)); + worker.statsReport(); + } + if (const char *domain = getenv("TEST_CLASSIFY_DOMAIN")) { test_mode = true; spdlog::info("Test mode: classifying domain '{}'", domain); diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index f90565d..7c46a45 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -198,6 +198,35 @@ void Worker::classifyDomain(const std::string &domain) { } } +void Worker::statsReport() { + int main_fd = 0; + try { + spdlog::info("Worker {} send stats", worker_id); + StatsReport req; + req.set_worker_id(worker_id); + + main_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (main_fd < 0) + throw WorkerException(std::string("socket: ") + strerror(errno)); + + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, SOCKET_DIR STATS_SOCKET_NAME, + sizeof(addr.sun_path) - 1); + + if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) + throw WorkerException(std::string("connect: ") + strerror(errno)); + + WriteProtoMessage(main_fd, req); + close(main_fd); + + } catch (const std::exception &e) { + close(main_fd); + SetState(WorkerState::ERROR); + throw WorkerException(std::string("statsReport: ") + e.what()); + } +} + Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_REGISTER); From 412d08158525d9da8e35a23a8e0b14631fb90338 Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Wed, 4 Mar 2026 17:48:49 +0300 Subject: [PATCH 6/8] update 5 --- worker/src/worker.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index 7c46a45..a5bf3eb 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -204,7 +204,7 @@ void Worker::statsReport() { spdlog::info("Worker {} send stats", worker_id); StatsReport req; req.set_worker_id(worker_id); - + main_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (main_fd < 0) throw WorkerException(std::string("socket: ") + strerror(errno)); From 5b10701dd015bee2aa696a578ff40b2998934beb Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Thu, 5 Mar 2026 02:32:55 +0300 Subject: [PATCH 7/8] update 6 --- controller/cmd/grpc_server/BUILD | 1 + controller/cmd/grpc_server/main.go | 4 + controller/internal/grpcserver/BUILD | 3 + controller/internal/grpcserver/data_server.go | 44 ++++ controller/internal/manager/manager.go | 204 ++---------------- controller/pkg/proto/communication/BUILD | 1 + .../proto/communication/communication.proto | 9 +- worker/BUILD | 94 +++++--- worker/MODULE.bazel | 8 +- worker/communication.proto | 9 +- worker/include/worker.hpp | 6 + worker/src/worker.cpp | 93 ++++---- 12 files changed, 203 insertions(+), 273 deletions(-) create mode 100644 controller/internal/grpcserver/data_server.go diff --git a/controller/cmd/grpc_server/BUILD b/controller/cmd/grpc_server/BUILD index 7aae62a..6923a67 100644 --- a/controller/cmd/grpc_server/BUILD +++ b/controller/cmd/grpc_server/BUILD @@ -17,6 +17,7 @@ go_library( "//internal/manager", "//pkg/proto/admin_service:admin_service_go_proto", "//pkg/proto/file_service:file_service_go_proto", + "//pkg/proto/communication:communication_go_proto", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//reflection", ], diff --git a/controller/cmd/grpc_server/main.go b/controller/cmd/grpc_server/main.go index 96e05a2..46d4ffd 100644 --- a/controller/cmd/grpc_server/main.go +++ b/controller/cmd/grpc_server/main.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/reflection" adminPb "github.com/moevm/grpc_server/pkg/proto/admin_service" + commPb "github.com/moevm/grpc_server/pkg/proto/communication" ) func main() { @@ -23,6 +24,8 @@ func main() { adminServer.SetManager(mgr) + dataServer := grpcserver.NewDataServer(mgr) + lis, err := net.Listen("tcp", net.JoinHostPort(cfg.Host, cfg.Port)) if err != nil { log.Fatalf("failed to listen: %v", err) @@ -36,6 +39,7 @@ func main() { service := grpc.NewServer(serverOpts...) adminPb.RegisterAdminServiceServer(service, adminServer) pb.RegisterFileServiceServer(service, grpcserver.NewServer(cfg.AllowedChars, mgr)) + commPb.RegisterDataServiceServer(service, dataServer) reflection.Register(service) log.Printf("Server starting on %s:%s", cfg.Host, cfg.Port) diff --git a/controller/internal/grpcserver/BUILD b/controller/internal/grpcserver/BUILD index 78ba59f..818872c 100644 --- a/controller/internal/grpcserver/BUILD +++ b/controller/internal/grpcserver/BUILD @@ -5,6 +5,7 @@ go_library( srcs = [ "server.go", "admin_server.go", + "data_server.go", ], importpath = "github.com/moevm/grpc_server/internal/grpcserver", visibility = ["//visibility:public"], @@ -12,9 +13,11 @@ go_library( "//internal/manager", "//pkg/proto/file_service:file_service_go_proto", "//pkg/proto/admin_service:admin_service_go_proto", + "//pkg/proto/communication:communication_go_proto", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_protobuf//types/known/emptypb", ] ) diff --git a/controller/internal/grpcserver/data_server.go b/controller/internal/grpcserver/data_server.go new file mode 100644 index 0000000..48af5c2 --- /dev/null +++ b/controller/internal/grpcserver/data_server.go @@ -0,0 +1,44 @@ +package grpcserver + +import ( + "context" + "log" + + "github.com/moevm/grpc_server/internal/manager" + pb "github.com/moevm/grpc_server/pkg/proto/communication" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +type DataServer struct { + pb.UnimplementedDataServiceServer + manager *manager.Manager +} + +func NewDataServer(mgr *manager.Manager) *DataServer { + return &DataServer{manager: mgr} +} + +func (s *DataServer) GetPolicy(ctx context.Context, req *pb.GetPolicyRequest) (*pb.WorkerPolicy, error) { + log.Printf("gRPC GetPolicy from worker %d (hash: %d)", req.WorkerId, req.PolicyHash) + policy := s.manager.GetWorkerPolicy(req.WorkerId) + if policy == nil { + return nil, status.Errorf(codes.NotFound, "no policy for worker %d", req.WorkerId) + } + return policy, nil +} + +func (s *DataServer) Classify(ctx context.Context, req *pb.ClassifyRequest) (*pb.ClassifyResponse, error) { + log.Printf("gRPC Classify from worker %d for domain: %s", req.WorkerId, req.Domain) + return &pb.ClassifyResponse{ + Categories: []string{"unknown"}, + TrustLevel: 50, + }, nil +} + +func (s *DataServer) SendStats(ctx context.Context, report *pb.StatsReport) (*emptypb.Empty, error) { + log.Printf("gRPC Stats from worker %d: blocked=%d allowed=%d", + report.WorkerId, report.TotalBlocked, report.TotalAllowed) + return &emptypb.Empty{}, nil +} \ No newline at end of file diff --git a/controller/internal/manager/manager.go b/controller/internal/manager/manager.go index ec490db..ccf45b3 100644 --- a/controller/internal/manager/manager.go +++ b/controller/internal/manager/manager.go @@ -37,10 +37,7 @@ const ( const ( workerMainSocketPath = "/run/controller/main.sock" - workerPolicySocketPath = "/run/controller/policy.sock" - workerClassifySocketPath = "/run/controller/classify.sock" - workerStatsSocketPath = "/run/controller/stats.sock" - workerSocketPath = "/run/controller/" + workerSocketPath = "/run/controller/" ) type IManager interface { @@ -50,14 +47,11 @@ type IManager interface { type Manager struct { listener net.Listener - policyListener net.Listener - classifyListener net.Listener - statsListener net.Listener policyManager *PolicyManager - workers map[uint64]*Worker - workersMutex sync.Mutex - workerId uint64 + workers map[uint64]*Worker + workersMutex sync.Mutex + workerId uint64 tasks map[uint64]*Task tasksMutex sync.Mutex @@ -335,147 +329,6 @@ func (m *Manager) mainLoop() { } -func (m *Manager) policyLoop() { - log.Print("Listening on policy.sock") - for { - select { - case <-m.shutdown: - return - default: - netConn, err := m.policyListener.Accept() - if err != nil { - log.Printf("Policy accept error: %v", err) - continue - } - go m.handlePolicyConnection(conn.Unix{Conn: netConn}) - } - } -} - -func (m *Manager) classifyLoop() { - log.Print("Listening on classify.sock") - for { - select { - case <-m.shutdown: - return - default: - netConn, err := m.classifyListener.Accept() - if err != nil { - log.Printf("Classify accept error: %v", err) - continue - } - go m.handleClassifyConnection(conn.Unix{Conn: netConn}) - } - } -} - -func (m *Manager) statsLoop() { - log.Print("Listening on stats.sock") - for { - select { - case <-m.shutdown: - return - default: - netConn, err := m.statsListener.Accept() - if err != nil { - log.Printf("Stats accept error: %v", err) - continue - } - go m.handleStatsConnection(conn.Unix{Conn: netConn}) - } - } -} - - -func (m *Manager) handlePolicyConnection(conn conn.Unix) { - defer conn.Close() - - msgData, err := conn.ReadMessage() - if err != nil { - m.errorChan <- fmt.Errorf("policy read error: %w", err) - return - } - - var req communication.GetPolicyRequest - if err := proto.Unmarshal(msgData, &req); err != nil { - m.errorChan <- fmt.Errorf("policy unmarshal error: %w", err) - return - } - - log.Printf("Policy request received from worker %d", req.WorkerId) - - policyBytes, needUpdate, err := m.HandleGetPolicy(req.WorkerId, req.PolicyHash) - if err != nil { - log.Printf("Error getting policy: %v", err) - return - } - - if !needUpdate { - log.Printf("Worker %d already has latest policy", req.WorkerId) - resp := &communication.WorkerResponse{ - Error: communication.WorkerError_WORKER_ERR_OK, - } - respBytes, _ := proto.Marshal(resp) - conn.WriteMessage(respBytes) - return - } - - conn.WriteMessage(policyBytes) - log.Printf("Policy sent to worker %d", req.WorkerId) -} - -func (m *Manager) handleClassifyConnection(conn conn.Unix) { - defer conn.Close() - - msgData, err := conn.ReadMessage() - if err != nil { - m.errorChan <- fmt.Errorf("classify read error: %w", err) - return - } - - var req communication.ClassifyRequest - if err := proto.Unmarshal(msgData, &req); err != nil { - m.errorChan <- fmt.Errorf("classify unmarshal error: %w", err) - return - } - - log.Printf("Classify request received from worker %d for domen '%s'", - req.WorkerId, req.Domain) - - resp := &communication.ClassifyResponse{ - Categories: []string{"unknown"}, - TrustLevel: 50, - } - - respData, err := proto.Marshal(resp) - if err != nil { - m.errorChan <- fmt.Errorf("marshal classify response error: %w", err) - return - } - - conn.WriteMessage(respData) - log.Printf("Classify response sent to worker %d", req.WorkerId) -} - -func (m *Manager) handleStatsConnection(conn conn.Unix) { - defer conn.Close() - - msgData, err := conn.ReadMessage() - if err != nil { - m.errorChan <- fmt.Errorf("stats read error: %w", err) - return - } - - var req communication.StatsReport - if err := proto.Unmarshal(msgData, &req); err != nil { - m.errorChan <- fmt.Errorf("stats unmarshal error: %w", err) - return - } - - log.Printf("Stats worker %d", req.WorkerId) - -} - // errorHandler catches the errors from goroutines and logs them. // This function should always be run in a goroutine. func (m *Manager) errorHandler() { @@ -653,20 +506,24 @@ func removeContents(dir string) error { } func (m *Manager) HandleGetPolicy(workerID uint64, currentHash uint64) ([]byte, bool, error) { - log.Printf("Worker %d requested policy", workerID) + log.Printf("Worker %d requested policy", workerID) + + policyProto := m.policyManager.GetWorkerPolicyProto(workerID) - policyProto := m.policyManager.GetWorkerPolicyProto(workerID) + if currentHash == policyProto.PolicyHash { + return nil, false, nil + } - if currentHash == policyProto.PolicyHash { - return nil, false, nil - } + policyBytes, err := proto.Marshal(policyProto) + if err != nil { + return nil, false, err + } - policyBytes, err := proto.Marshal(policyProto) - if err != nil { - return nil, false, err - } + return policyBytes, true, nil +} - return policyBytes, true, nil +func (m *Manager) GetWorkerPolicy(workerID uint64) *communication.WorkerPolicy { + return m.policyManager.GetWorkerPolicyProto(workerID) } func (m *Manager) UpdateConfig(configData []byte) { @@ -685,26 +542,9 @@ func NewManager() (*Manager, error) { return nil, fmt.Errorf("failed to create listener: %w", err) } - policyListener, err := net.Listen("unix", workerPolicySocketPath) - if err != nil { - return nil, fmt.Errorf("failed to create policy listener: %w", err) - } - - classifyListener, err := net.Listen("unix", workerClassifySocketPath) - if err != nil { - return nil, fmt.Errorf("failed to create classify listener: %w", err) - } - - statsListener, err := net.Listen("unix", workerStatsSocketPath) - if err != nil { - return nil, fmt.Errorf("failed to create stats listener: %w", err) - } m := &Manager{ listener: listener, - policyListener: policyListener, - classifyListener: classifyListener, - statsListener: statsListener, policyManager: NewPolicyManager(), workers: make(map[uint64]*Worker), tasks: make(map[uint64]*Task), @@ -717,9 +557,6 @@ func NewManager() (*Manager, error) { } go m.mainLoop() - go m.policyLoop() - go m.classifyLoop() - go m.statsLoop() go m.errorHandler() go m.dispatchTasks() go m.checkHealth() @@ -732,9 +569,6 @@ func (m *Manager) Shutdown() { m.shutdownOnce.Do(func() { close(m.shutdown) m.listener.Close() - m.policyListener.Close() - m.classifyListener.Close() - m.statsListener.Close() close(m.freeWorkers) close(m.fetchWorkers) @@ -742,4 +576,4 @@ func (m *Manager) Shutdown() { close(m.taskSolutions) close(m.errorChan) }) -} +} \ No newline at end of file diff --git a/controller/pkg/proto/communication/BUILD b/controller/pkg/proto/communication/BUILD index 5fc229e..d967a96 100644 --- a/controller/pkg/proto/communication/BUILD +++ b/controller/pkg/proto/communication/BUILD @@ -7,6 +7,7 @@ proto_library( visibility = ["//visibility:public"], deps = [ "@protobuf//:struct_proto", + "@protobuf//:empty_proto", ], ) diff --git a/controller/pkg/proto/communication/communication.proto b/controller/pkg/proto/communication/communication.proto index 7477f65..185fc54 100644 --- a/controller/pkg/proto/communication/communication.proto +++ b/controller/pkg/proto/communication/communication.proto @@ -1,6 +1,13 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; import "google/protobuf/struct.proto"; +import "google/protobuf/empty.proto"; + +service DataService { + rpc GetPolicy(GetPolicyRequest) returns (WorkerPolicy); + rpc Classify(ClassifyRequest) returns (ClassifyResponse); + rpc SendStats(StatsReport) returns (google.protobuf.Empty); +} enum PulseType { PULSE_INVALD = 0; @@ -88,7 +95,7 @@ message ClassifyResponse { message ResourceStats { string domain = 1; - uint64 blosked = 2; + uint64 blocked = 2; uint64 allowed = 3; } diff --git a/worker/BUILD b/worker/BUILD index 1fe36cd..047688c 100644 --- a/worker/BUILD +++ b/worker/BUILD @@ -1,28 +1,12 @@ -cc_binary( - name = "worker", - srcs = [ - "src/main.cpp", - "src/md_calculator.cpp", - "src/file.cpp", - "src/worker.cpp", - "src/metrics_collector.cpp", - "include/file.hpp", - "include/md_calculator.hpp", - "include/worker.hpp", - "include/metrics_collector.hpp", - ], - deps = ["@spdlog//:spdlog", - "@curl//:curl", - ":communication_cc_proto"], - includes = ["/usr/local/include"], - linkopts = [ - "-L/usr/local/openssl/lib", - "-lssl", - "-lcrypto", - "-L/usr/local/lib", - "-lprometheus-cpp-push", - "-lprometheus-cpp-core", - ], +load("@grpc//bazel:generate_cc.bzl", "generate_cc") + +proto_library( + name = "communication_proto", + srcs = ["communication.proto"], + deps = [ + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + ], ) cc_proto_library( @@ -30,10 +14,62 @@ cc_proto_library( deps = [":communication_proto"], ) -proto_library( - name = "communication_proto", - srcs = [ "communication.proto" ], +generate_cc( + name = "communication_cc_grpc_gen", + srcs = [":communication_proto"], + plugin = "@grpc//src/compiler:grpc_cpp_plugin", + well_known_protos = True, + generate_mocks = True, +) + +cc_library( + name = "communication_cc_grpc", + srcs = [":communication_cc_grpc_gen"], + hdrs = [":communication_cc_grpc_gen"], deps = [ - "@protobuf//:struct_proto", + ":communication_cc_proto", + "@grpc//:grpc++", + ], +) + +cc_library( + name = "worker_headers", + hdrs = [ + "include/worker.hpp", + "include/md_calculator.hpp", + "include/file.hpp", + "include/metrics_collector.hpp", + ], + srcs = [], + visibility = ["//visibility:public"], +) + +cc_binary( + name = "worker", + srcs = [ + "src/main.cpp", + "src/md_calculator.cpp", + "src/file.cpp", + "src/worker.cpp", + "src/metrics_collector.cpp", + ], + deps = [ + ":worker_headers", + ":communication_cc_proto", + ":communication_cc_grpc", + "@grpc//:grpc++", + "@spdlog//:spdlog", + "@curl//:curl", + ], + copts = [ + "-I$(GENDIR)/..", + ], + linkopts = [ + "-L/usr/local/openssl/lib", + "-lssl", + "-lcrypto", + "-L/usr/local/lib", + "-lprometheus-cpp-push", + "-lprometheus-cpp-core", ], ) diff --git a/worker/MODULE.bazel b/worker/MODULE.bazel index dc3d2aa..6f3f40d 100644 --- a/worker/MODULE.bazel +++ b/worker/MODULE.bazel @@ -1,10 +1,14 @@ module( - name = "worker", + name = "worker", + version = "0.1.0", ) bazel_dep(name = "rules_cc", version = "0.1.1") bazel_dep(name = "platforms", version = "0.0.11") bazel_dep(name = "spdlog", version = "1.15.2") bazel_dep(name = "abseil-cpp", version = "20250512.1") -bazel_dep(name = "protobuf", version = "32.0-rc1") bazel_dep(name = "curl", version = "8.8.0.bcr.3") +bazel_dep(name = "bazel_skylib", version = "1.8.1") +bazel_dep(name = "rules_proto", version = "7.0.2") +bazel_dep(name = "protobuf", version = "31.1", repo_name = "com_google_protobuf") +bazel_dep(name = "grpc", version = "1.78.0") diff --git a/worker/communication.proto b/worker/communication.proto index 7477f65..185fc54 100644 --- a/worker/communication.proto +++ b/worker/communication.proto @@ -1,6 +1,13 @@ syntax = "proto3"; option go_package = "github.com/moevm/grpc_server/pkg/proto/communication"; import "google/protobuf/struct.proto"; +import "google/protobuf/empty.proto"; + +service DataService { + rpc GetPolicy(GetPolicyRequest) returns (WorkerPolicy); + rpc Classify(ClassifyRequest) returns (ClassifyResponse); + rpc SendStats(StatsReport) returns (google.protobuf.Empty); +} enum PulseType { PULSE_INVALD = 0; @@ -88,7 +95,7 @@ message ClassifyResponse { message ResourceStats { string domain = 1; - uint64 blosked = 2; + uint64 blocked = 2; uint64 allowed = 3; } diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 06a8e5c..4ae473c 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -5,6 +5,10 @@ #include #include +#include +#include "communication.grpc.pb.h" + +#include #define SOCKET_DIR "/run/controller/" #define MAIN_SOCKET_NAME "main.sock" @@ -46,6 +50,8 @@ class Worker { std::string fetch_data; std::string extra_data; + std::unique_ptr stub_; + enum class InitResponse : uint64_t { OK = 1 }; WorkerState state; void LogStateChange(WorkerState new_state); diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index a5bf3eb..a5deeec 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include "communication.grpc.pb.h" void Worker::LogStateChange(WorkerState new_state) { const char *state_names[] = {"BOOTING", "FREE", "BUSY", "SHUTTING_DOWN", @@ -126,7 +128,6 @@ void Worker::SendPulse(PulseType type) { } void Worker::requestPolicyFromController() { - int main_fd = 0; try { spdlog::info("Worker {} requests policy", worker_id); @@ -135,101 +136,83 @@ void Worker::requestPolicyFromController() { req.set_policy_hash(current_policy_hash); req.set_config_version(current_config_version); - main_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (main_fd < 0) - throw WorkerException(std::string("socket: ") + strerror(errno)); + WorkerPolicy policy; + grpc::ClientContext context; - sockaddr_un addr{}; - addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, SOCKET_DIR POLICY_SOCKET_NAME, - sizeof(addr.sun_path) - 1); + auto status = stub_->GetPolicy(&context, req, &policy); + if (!status.ok()) { + throw WorkerException("GetPolicy failed: " + status.error_message()); + } - if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) - throw WorkerException(std::string("connect: ") + strerror(errno)); - WriteProtoMessage(main_fd, req); - WorkerPolicy policy; - ReadProtoMessage(main_fd, policy); current_policy_hash = policy.policy_hash(); - spdlog::info("Policy received", policy.ShortDebugString()); + spdlog::info("Policy received"); - close(main_fd); } catch (const std::exception &e) { - close(main_fd); SetState(WorkerState::ERROR); - spdlog::error("requestPolicyFromController failed: {}", e.what()); - throw WorkerException(std::string("requestPolicyFromController: ") + - e.what()); + throw WorkerException(std::string("requestPolicyFromController: ") + e.what()); } } void Worker::classifyDomain(const std::string &domain) { - int main_fd = 0; try { spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); + ClassifyRequest req; req.set_worker_id(worker_id); req.set_domain(domain); - main_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (main_fd < 0) - throw WorkerException(std::string("socket: ") + strerror(errno)); - - sockaddr_un addr{}; - addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, SOCKET_DIR CLASSIFY_SOCKET_NAME, - sizeof(addr.sun_path) - 1); - - if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) - throw WorkerException(std::string("connect: ") + strerror(errno)); - - WriteProtoMessage(main_fd, req); ClassifyResponse resp; - ReadProtoMessage(main_fd, resp); + grpc::ClientContext context; - spdlog::info("Domain '{}' classified as category '{}' with trust level {}", - domain, resp.categories(0), resp.trust_level()); + auto status = stub_->Classify(&context, req, &resp); + if (!status.ok()) { + throw WorkerException("Classify failed: " + status.error_message()); + } - close(main_fd); + std::string cat = resp.categories_size() > 0 ? resp.categories(0) : "unknown"; + spdlog::info("Domain '{}' classified as category '{}' with trust level {}", + domain, cat, resp.trust_level()); } catch (const std::exception &e) { - close(main_fd); SetState(WorkerState::ERROR); throw WorkerException(std::string("classifyDomain: ") + e.what()); } } void Worker::statsReport() { - int main_fd = 0; try { spdlog::info("Worker {} send stats", worker_id); - StatsReport req; - req.set_worker_id(worker_id); - main_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (main_fd < 0) - throw WorkerException(std::string("socket: ") + strerror(errno)); + StatsReport report; + report.set_worker_id(worker_id); + report.set_time(time(nullptr)); - sockaddr_un addr{}; - addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, SOCKET_DIR STATS_SOCKET_NAME, - sizeof(addr.sun_path) - 1); + grpc::ClientContext context; + google::protobuf::Empty response; - if (connect(main_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) - throw WorkerException(std::string("connect: ") + strerror(errno)); + auto status = stub_->SendStats(&context, report, &response); + if (!status.ok()) { + throw WorkerException("SendStats failed: " + status.error_message()); + } - WriteProtoMessage(main_fd, req); - close(main_fd); + spdlog::info("Stats sent successfully"); } catch (const std::exception &e) { - close(main_fd); - SetState(WorkerState::ERROR); - throw WorkerException(std::string("statsReport: ") + e.what()); + spdlog::error("statsReport failed: {}", e.what()); } } Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_REGISTER); + std::string controller_addr = "localhost:50051"; + if (const char* env_addr = getenv("CONTROLLER_GRPC_ADDR")) { + controller_addr = env_addr; + } + auto channel = grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); + stub_ = DataService::NewStub(channel); + spdlog::info("gRPC channel created to {}", controller_addr); + socket_path = std::string(SOCKET_DIR) + std::to_string(worker_id) + ".sock"; unlink(socket_path.c_str()); From f550f60a43bce4a8f9c72fb31394bb83feec808a Mon Sep 17 00:00:00 2001 From: stepanrodimanov Date: Thu, 5 Mar 2026 10:41:29 +0300 Subject: [PATCH 8/8] update 6 --- admin/.env | 2 ++ admin/README.md | 6 +++--- admin/admin.py | 2 +- admin/admin_service.proto | 17 +++++++++++++++++ worker/include/worker.hpp | 4 ++-- worker/src/worker.cpp | 15 +++++++++------ 6 files changed, 34 insertions(+), 12 deletions(-) create mode 100644 admin/.env create mode 100644 admin/admin_service.proto diff --git a/admin/.env b/admin/.env new file mode 100644 index 0000000..d5bdf88 --- /dev/null +++ b/admin/.env @@ -0,0 +1,2 @@ +SERVER_HOST=localhost +SERVER_PORT=50051 \ No newline at end of file diff --git a/admin/README.md b/admin/README.md index 635c7da..1a96737 100644 --- a/admin/README.md +++ b/admin/README.md @@ -10,11 +10,11 @@ pip install -r requirements.txt python -m grpc_tools.protoc \ --python_out=. \ --grpc_python_out=. \ - -I ../controller/pkg/proto/admin_service \ - ../controller/pkg/proto/admin_service/admin_service.proto + -I . \ + admin_service.proto ``` -## Запуск контроллера (в отдельном терминале) +## Запуск контроллера ```bash cd ../controller bazel run //cmd/grpc_server:grpc_server diff --git a/admin/admin.py b/admin/admin.py index 613afc3..016ce10 100644 --- a/admin/admin.py +++ b/admin/admin.py @@ -6,7 +6,7 @@ import admin_service_pb2 import admin_service_pb2_grpc -load_dotenv(os.path.join(os.path.dirname(__file__), "..", "controller", ".env")) +load_dotenv(".env") class AdminClient: def __init__(self): diff --git a/admin/admin_service.proto b/admin/admin_service.proto new file mode 100644 index 0000000..c587721 --- /dev/null +++ b/admin/admin_service.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package admin_service; + +option go_package = "github.com/moevm/grpc_server/pkg/proto/admin_service"; + +service AdminService { + rpc LoadConfig(LoadConfigRequest) returns (LoadConfigResponse) {} +} + +message LoadConfigRequest { + bytes config_data = 1; +} + +message LoadConfigResponse { + bool success = 1; +} diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 4ae473c..a727c9a 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -3,10 +3,10 @@ #include "communication.pb.h" +#include "communication.grpc.pb.h" #include +#include #include -#include -#include "communication.grpc.pb.h" #include diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index a5deeec..e596ee1 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -1,8 +1,10 @@ #include "../include/worker.hpp" +#include "communication.grpc.pb.h" #include #include #include +#include #include #include #include @@ -10,8 +12,6 @@ #include #include #include -#include -#include "communication.grpc.pb.h" void Worker::LogStateChange(WorkerState new_state) { const char *state_names[] = {"BOOTING", "FREE", "BUSY", "SHUTTING_DOWN", @@ -149,7 +149,8 @@ void Worker::requestPolicyFromController() { } catch (const std::exception &e) { SetState(WorkerState::ERROR); - throw WorkerException(std::string("requestPolicyFromController: ") + e.what()); + throw WorkerException(std::string("requestPolicyFromController: ") + + e.what()); } } @@ -169,7 +170,8 @@ void Worker::classifyDomain(const std::string &domain) { throw WorkerException("Classify failed: " + status.error_message()); } - std::string cat = resp.categories_size() > 0 ? resp.categories(0) : "unknown"; + std::string cat = + resp.categories_size() > 0 ? resp.categories(0) : "unknown"; spdlog::info("Domain '{}' classified as category '{}' with trust level {}", domain, cat, resp.trust_level()); @@ -206,10 +208,11 @@ Worker::Worker() : listener_fd(-1), state(WorkerState::BOOTING) { SendPulse(PULSE_REGISTER); std::string controller_addr = "localhost:50051"; - if (const char* env_addr = getenv("CONTROLLER_GRPC_ADDR")) { + if (const char *env_addr = getenv("CONTROLLER_GRPC_ADDR")) { controller_addr = env_addr; } - auto channel = grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); + auto channel = + grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); stub_ = DataService::NewStub(channel); spdlog::info("gRPC channel created to {}", controller_addr);