diff --git a/cli/sdk.go b/cli/sdk.go index aa545a2c5..672cf5001 100644 --- a/cli/sdk.go +++ b/cli/sdk.go @@ -15,12 +15,12 @@ var Verbose bool type CLI struct { agentSDK sdk.SDK - config grpc.Config + config grpc.AgentClientConfig client grpc.Client connectErr error } -func New(config grpc.Config) *CLI { +func New(config grpc.AgentClientConfig) *CLI { return &CLI{ config: config, } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 1c7f69d29..ab0881543 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -97,14 +97,18 @@ func main() { svc := newService(ctx, logger, eventSvc, cfg, qp) - grpcServerConfig := server.Config{ - Port: cfg.AgentConfig.Port, - Host: cfg.AgentConfig.Host, - CertFile: cfg.AgentConfig.CertFile, - KeyFile: cfg.AgentConfig.KeyFile, - ServerCAFile: cfg.AgentConfig.ServerCAFile, - ClientCAFile: cfg.AgentConfig.ClientCAFile, - AttestedTLS: cfg.AgentConfig.AttestedTls, + agentGrpcServerConfig := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: cfg.AgentConfig.Host, + Port: cfg.AgentConfig.Port, + CertFile: cfg.AgentConfig.CertFile, + KeyFile: cfg.AgentConfig.KeyFile, + ServerCAFile: cfg.AgentConfig.ServerCAFile, + ClientCAFile: cfg.AgentConfig.ClientCAFile, + }, + }, + AttestedTLS: cfg.AgentConfig.AttestedTls, } registerAgentServiceServer := func(srv *grpc.Server) { @@ -119,7 +123,7 @@ func main() { return } - gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, qp, authSvc) + gs := grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, logger, qp, authSvc) g.Go(func() error { for { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 63e765752..916e37032 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -91,7 +91,7 @@ func main() { return } - agentGRPCConfig := grpc.Config{} + agentGRPCConfig := grpc.AgentClientConfig{} if err := env.ParseWithOptions(&agentGRPCConfig, env.Options{Prefix: envPrefixAgentGRPC}); err != nil { message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err) rootCmd.Println(message) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 226607f9c..bb9933d4c 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -25,7 +25,7 @@ import ( "github.com/ultravioletrs/cocos/manager/events" "github.com/ultravioletrs/cocos/manager/qemu" "github.com/ultravioletrs/cocos/manager/tracing" - "github.com/ultravioletrs/cocos/pkg/clients/grpc" + pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -92,7 +92,7 @@ func main() { args := qemuCfg.ConstructQemuArgs() logger.Info(strings.Join(args, " ")) - managerGRPCConfig := grpc.Config{} + managerGRPCConfig := pkggrpc.ManagerClientConfig{} if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) exitCode = 1 diff --git a/go.mod b/go.mod index 8fbd0b159..48c3b8c9c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/ultravioletrs/cocos -go 1.22.7 - -toolchain go1.23.1 +go 1.23.0 require ( github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index c3ce35a41..33ae890c1 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -60,8 +60,9 @@ type serviceRegister func(srv *grpc.Server) var _ server.Server = (*Server)(nil) -func New(ctx context.Context, cancel context.CancelFunc, name string, config server.Config, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { - listenFullAddress := fmt.Sprintf("%s:%s", config.Host, config.Port) +func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { + base := config.GetBaseConfig() + listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port) return &Server{ BaseServer: server.BaseServer{ Ctx: ctx, @@ -91,101 +92,98 @@ func (s *Server) Start() error { creds := grpc.Creds(insecure.NewCredentials()) var listener net.Listener = nil + switch c := s.Config.(type) { + case server.AgentConfig: + switch { + case c.AttestedTLS: + certificateBytes, privateKeyBytes, err := generateCertificatesForATLS() + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } - switch { - case s.Config.AttestedTLS: - certificateBytes, privateKeyBytes, err := generateCertificatesForATLS() - if err != nil { - return fmt.Errorf("failed to create certificate: %w", err) - } - - certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) - if err != nil { - return fmt.Errorf("falied due to invalid key pair: %w", err) - } - - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - Certificates: []tls.Certificate{certificate}, - } + certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) + if err != nil { + return fmt.Errorf("falied due to invalid key pair: %w", err) + } - creds = grpc.Creds(credentials.NewTLS(tlsConfig)) + tlsConfig := &tls.Config{ + ClientAuth: tls.NoClientCert, + Certificates: []tls.Certificate{certificate}, + } - listener, err = atls.Listen( - s.Address, - certificateBytes, - privateKeyBytes, - ) - if err != nil { - return fmt.Errorf("failed to create Listener for aTLS: %w", err) - } - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address)) + creds = grpc.Creds(credentials.NewTLS(tlsConfig)) - case s.Config.CertFile != "" || s.Config.KeyFile != "": - certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile) - if err != nil { - return fmt.Errorf("failed to load auth certificates: %w", err) - } - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - Certificates: []tls.Certificate{certificate}, - } + listener, err = atls.Listen( + s.Address, + certificateBytes, + privateKeyBytes, + ) + if err != nil { + return fmt.Errorf("failed to create Listener for aTLS: %w", err) + } + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address)) - var mtlsCA string - // Loading Server CA file - rootCA, err := loadCertFile(s.Config.ServerCAFile) - if err != nil { - return fmt.Errorf("failed to load root ca file: %w", err) - } - if len(rootCA) > 0 { - if tlsConfig.RootCAs == nil { - tlsConfig.RootCAs = x509.NewCertPool() + case c.CertFile != "" || c.KeyFile != "": + certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return fmt.Errorf("failed to load auth certificates: %w", err) } - if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { - return fmt.Errorf("failed to append root ca to tls.Config") + tlsConfig := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{certificate}, } - mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile) - } - // Loading Client CA File - clientCA, err := loadCertFile(s.Config.ClientCAFile) - if err != nil { - return fmt.Errorf("failed to load client ca file: %w", err) - } - if len(clientCA) > 0 { - if tlsConfig.ClientCAs == nil { - tlsConfig.ClientCAs = x509.NewCertPool() + var mtlsCA string + // Loading Server CA file + rootCA, err := loadCertFile(c.ServerCAFile) + if err != nil { + return fmt.Errorf("failed to load root ca file: %w", err) } - if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { - return fmt.Errorf("failed to append client ca to tls.Config") + if len(rootCA) > 0 { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { + return fmt.Errorf("failed to append root ca to tls.Config") + } + mtlsCA = fmt.Sprintf("root ca %s", c.ServerCAFile) } - mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile) - } - if mtlsCA != "" { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - } + // Loading Client CA File + clientCA, err := loadCertFile(c.ClientCAFile) + if err != nil { + return fmt.Errorf("failed to load client ca file: %w", err) + } + if len(clientCA) > 0 { + if tlsConfig.ClientCAs == nil { + tlsConfig.ClientCAs = x509.NewCertPool() + } + if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { + return fmt.Errorf("failed to append client ca to tls.Config") + } + mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, c.ClientCAFile) + } + creds = grpc.Creds(credentials.NewTLS(tlsConfig)) + switch { + case mtlsCA != "": + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA)) + default: + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile)) + } - creds = grpc.Creds(credentials.NewTLS(tlsConfig)) - switch { - case mtlsCA != "": - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address)) + listener, err = net.Listen("tcp", s.Address) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + } default: - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address)) - } + var err error - listener, err = net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) - } - default: - var err error - - listener, err = net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + listener, err = net.Listen("tcp", s.Address) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + } + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address)) } - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address)) } grpcServerOptions = append(grpcServerOptions, creds) diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index 248980e4a..b5bcbaebe 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -38,9 +38,13 @@ func TestNew(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - config := server.Config{ - Host: "localhost", - Port: "50051", + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "50051", + }, + }, } logger := slog.Default() qp := new(mocks.QuoteProvider) @@ -80,11 +84,15 @@ func TestServerStartWithTLSFile(t *testing.T) { err = keyFile.Close() assert.NoError(t, err) - config := server.Config{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: certFile.Name(), + KeyFile: keyFile.Name(), + }, + }, } logBuffer := &ThreadSafeBuffer{} @@ -119,38 +127,19 @@ func TestServerStartWithTLSFile(t *testing.T) { func TestServerStartWithmTLSFile(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - certFile, err := os.CreateTemp("", "cert*.pem") - assert.NoError(t, err) - - keyFile, err := os.CreateTemp("", "key*.pem") + caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles() assert.NoError(t, err) - t.Cleanup(func() { - os.Remove(certFile.Name()) - os.Remove(keyFile.Name()) - }) - - _, err = certFile.Write(cert) - assert.NoError(t, err) - - _, err = keyFile.Write(key) - assert.NoError(t, err) - - err = certFile.Close() - assert.NoError(t, err) - err = keyFile.Close() - assert.NoError(t, err) - - config := server.Config{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), - ServerCAFile: certFile.Name(), - ClientCAFile: certFile.Name(), + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: string(clientCertFile), + KeyFile: string(clientKeyFile), + ServerCAFile: caCertFile, + }, + }, } logBuffer := &ThreadSafeBuffer{} @@ -185,9 +174,13 @@ func TestServerStartWithmTLSFile(t *testing.T) { func TestServerStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - config := server.Config{ - Host: "localhost", - Port: "0", + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, } buf := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) @@ -268,54 +261,74 @@ func (b *ThreadSafeBuffer) String() string { func TestServerInitializationAndStartup(t *testing.T) { testCases := []struct { name string - config server.Config + config server.AgentConfig expectedLog string expectError bool - setupCallback func(*testing.T, *server.Config, *ThreadSafeBuffer) + setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer) }{ { name: "Non-TLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS", }, { name: "TLS Server Startup with Self-Signed Certificate", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, setupCallback: setupTLSConfig, expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", }, { name: "TLS Server Startup with Invalid Certificates", - config: server.Config{ - Host: "localhost", - Port: "0", - CertFile: "invalid", - KeyFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: "invalid", + KeyFile: "invalid", + }, + }, }, expectError: true, expectedLog: "failed to load auth certificates", }, { name: "mTLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, setupCallback: setupMTLSConfig, expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", }, { name: "mTLS Server Startup with Invalid Root CA", - config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + }, }, setupCallback: setupInvalidRootCAConfig, expectError: true, @@ -323,10 +336,14 @@ func TestServerInitializationAndStartup(t *testing.T) { }, { name: "mTLS Server Startup with Invalid Client CA", - config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + }, }, setupCallback: setupInvalidClientCAConfig, expectError: true, @@ -334,9 +351,13 @@ func TestServerInitializationAndStartup(t *testing.T) { }, { name: "Attested TLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, AttestedTLS: true, }, expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS", @@ -347,7 +368,6 @@ func TestServerInitializationAndStartup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if tc.setupCallback != nil { tc.setupCallback(t, &tc.config, nil) } @@ -358,7 +378,6 @@ func TestServerInitializationAndStartup(t *testing.T) { authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc) - var wg sync.WaitGroup wg.Add(1) @@ -390,7 +409,7 @@ func TestServerInitializationAndStartup(t *testing.T) { } } -func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -398,7 +417,7 @@ func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { config.KeyFile = string(key) } -func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -408,7 +427,7 @@ func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { config.ClientCAFile = string(cert) } -func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -418,7 +437,7 @@ func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafe config.ClientCAFile = string(cert) } -func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -427,3 +446,89 @@ func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSa config.ClientCAFile = "invalid" config.ServerCAFile = string(cert) } + +func createCertificatesFiles() (string, string, string, error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + caTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})) + if err != nil { + return "", "", "", err + } + + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + clientTemplate := x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})) + if err != nil { + return "", "", "", err + } + + clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})) + if err != nil { + return "", "", "", err + } + + return caCertFile, clientCertFile, clientKeyFile, nil +} + +func createTempFile(data []byte) (string, error) { + file, err := createTempFileHandle() + if err != nil { + return "", err + } + + _, err = file.Write(data) + if err != nil { + return "", err + } + + err = file.Close() + if err != nil { + return "", err + } + + return file.Name(), nil +} + +func createTempFileHandle() (*os.File, error) { + return os.CreateTemp("", "test") +} diff --git a/internal/server/server.go b/internal/server/server.go index c50e9a11a..188c24957 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,14 +16,25 @@ type Server interface { Stop() error } -type Config struct { - Host string `env:"HOST" envDefault:""` - Port string `env:"PORT" envDefault:""` +type ServerConfiguration interface { + GetBaseConfig() ServerConfig +} + +type BaseConfig struct { + Host string `env:"HOST" envDefault:"localhost"` + Port string `env:"PORT" envDefault:"7001"` + ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` CertFile string `env:"SERVER_CERT" envDefault:""` KeyFile string `env:"SERVER_KEY" envDefault:""` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""` - AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` +} + +type ServerConfig struct { + BaseConfig +} +type AgentConfig struct { + ServerConfig + AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` } type BaseServer struct { @@ -31,11 +42,19 @@ type BaseServer struct { Cancel context.CancelFunc Name string Address string - Config Config + Config ServerConfiguration Logger *slog.Logger Protocol string } +func (s ServerConfig) GetBaseConfig() ServerConfig { + return s +} + +func (a AgentConfig) GetBaseConfig() ServerConfig { + return a.ServerConfig +} + func stopAllServer(servers ...Server) error { var errs []error for _, server := range servers { diff --git a/pkg/clients/grpc/agent/agent.go b/pkg/clients/grpc/agent/agent.go index 9e514927a..4880a2c28 100644 --- a/pkg/clients/grpc/agent/agent.go +++ b/pkg/clients/grpc/agent/agent.go @@ -14,7 +14,7 @@ import ( var ErrAgentServiceUnavailable = errors.New("agent service is unavailable") // NewAgentClient creates new agent gRPC client instance. -func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.AgentServiceClient, error) { +func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Client, agent.AgentServiceClient, error) { client, err := grpc.NewClient(cfg) if err != nil { return nil, nil, err diff --git a/pkg/clients/grpc/agent/agent_test.go b/pkg/clients/grpc/agent/agent_test.go index 539e474df..ea6b8402f 100644 --- a/pkg/clients/grpc/agent/agent_test.go +++ b/pkg/clients/grpc/agent/agent_test.go @@ -78,32 +78,38 @@ func TestAgentClientIntegration(t *testing.T) { tests := []struct { name string serverRunning bool - config pkggrpc.Config + config pkggrpc.AgentClientConfig err error }{ { name: "successful connection", serverRunning: true, - config: pkggrpc.Config{ - URL: testServer.listenAddr, - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: testServer.listenAddr, + Timeout: 1, + }, }, err: nil, }, { name: "server not healthy", serverRunning: false, - config: pkggrpc.Config{ - URL: "", - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: "", + Timeout: 1, + }, }, err: ErrAgentServiceUnavailable, }, { name: "invalid config, missing AttestationPolicy with aTLS", - config: pkggrpc.Config{ - URL: testServer.listenAddr, - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: testServer.listenAddr, + Timeout: 1, + }, AttestedTLS: true, }, err: pkggrpc.ErrAttestationPolicyMissing, diff --git a/pkg/clients/grpc/connect.go b/pkg/clients/grpc/connect.go index c091bfbf0..aeb24db9a 100644 --- a/pkg/clients/grpc/connect.go +++ b/pkg/clients/grpc/connect.go @@ -50,14 +50,38 @@ var ( errFailedToLoadRootCA = errors.New("failed to load root ca file") ) -type Config struct { - ClientCert string `env:"CLIENT_CERT" envDefault:""` - ClientKey string `env:"CLIENT_KEY" envDefault:""` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` - URL string `env:"URL" envDefault:"localhost:7001"` - Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` - AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` - AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""` +type ClientConfiguration interface { + GetBaseConfig() BaseConfig +} + +type BaseConfig struct { + URL string `env:"URL" envDefault:"localhost:7001"` + Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` + ClientCert string `env:"CLIENT_CERT" envDefault:""` + ClientKey string `env:"CLIENT_KEY" envDefault:""` + ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` +} + +type AgentClientConfig struct { + BaseConfig + AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""` + AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` +} + +type ManagerClientConfig struct { + BaseConfig +} + +func (a BaseConfig) GetBaseConfig() BaseConfig { + return a +} + +func (a AgentClientConfig) GetBaseConfig() BaseConfig { + return a.BaseConfig +} + +func (a ManagerClientConfig) GetBaseConfig() BaseConfig { + return a.BaseConfig } type Client interface { @@ -73,13 +97,13 @@ type Client interface { type client struct { *grpc.ClientConn - cfg Config + cfg ClientConfiguration secure security } var _ Client = (*client)(nil) -func NewClient(cfg Config) (Client, error) { +func NewClient(cfg ClientConfiguration) (Client, error) { conn, secure, err := connect(cfg) if err != nil { return nil, err @@ -120,15 +144,15 @@ func (c *client) Connection() *grpc.ClientConn { } // connect creates new gRPC client and connect to gRPC server. -func connect(cfg Config) (*grpc.ClientConn, security, error) { +func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) { opts := []grpc.DialOption{ grpc.WithStatsHandler(otelgrpc.NewClientHandler()), } secure := withoutTLS - tc := insecure.NewCredentials() + var tc credentials.TransportCredentials - if cfg.AttestedTLS { - err := ReadAttestationPolicy(cfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP) + if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS { + err := ReadAttestationPolicy(agcfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP) if err != nil { return nil, secure, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err) } @@ -141,46 +165,60 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) { opts = append(opts, grpc.WithContextDialer(CustomDialer)) secure = withaTLS } else { - if cfg.ServerCAFile != "" { - tlsConfig := &tls.Config{} - - // Loading root ca certificates file - rootCA, err := os.ReadFile(cfg.ServerCAFile) - if err != nil { - return nil, secure, errors.Wrap(errFailedToLoadRootCA, err) - } - if len(rootCA) > 0 { - capool := x509.NewCertPool() - if !capool.AppendCertsFromPEM(rootCA) { - return nil, secure, fmt.Errorf("failed to append root ca to tls.Config") - } - tlsConfig.RootCAs = capool - secure = withTLS - } - - // Loading mTLS certificates file - if cfg.ClientCert != "" || cfg.ClientKey != "" { - certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) - if err != nil { - return nil, secure, errors.Wrap(errFailedToLoadClientCertKey, err) - } - tlsConfig.Certificates = []tls.Certificate{certificate} - secure = withmTLS - } - - tc = credentials.NewTLS(tlsConfig) + conf := cfg.GetBaseConfig() + transportCreds, err, sec := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) + if err != nil { + return nil, secure, err } + tc = transportCreds + secure = sec } opts = append(opts, grpc.WithTransportCredentials(tc)) - conn, err := grpc.NewClient(cfg.URL, opts...) + conn, err := grpc.NewClient(cfg.GetBaseConfig().URL, opts...) if err != nil { return nil, secure, errors.Wrap(errGrpcConnect, err) } return conn, secure, nil } +func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, error, security) { + tlsConfig := &tls.Config{} + secure := withoutTLS + tc := insecure.NewCredentials() + + // Load Root CA certificates + if serverCAFile != "" { + rootCA, err := os.ReadFile(serverCAFile) + if err != nil { + return nil, errors.Wrap(errFailedToLoadRootCA, err), secure + } + if len(rootCA) > 0 { + capool := x509.NewCertPool() + if !capool.AppendCertsFromPEM(rootCA) { + return nil, fmt.Errorf("failed to append root ca to tls.Config"), secure + } + tlsConfig.RootCAs = capool + secure = withTLS + tc = credentials.NewTLS(tlsConfig) + } + } + + // Load mTLS certificates + if clientCert != "" || clientKey != "" { + certificate, err := tls.LoadX509KeyPair(clientCert, clientKey) + if err != nil { + return nil, errors.Wrap(errFailedToLoadClientCertKey, err), secure + } + tlsConfig.Certificates = []tls.Certificate{certificate} + secure = withmTLS + tc = credentials.NewTLS(tlsConfig) + } + + return tc, nil, secure +} + func ReadAttestationPolicy(manifestPath string, attestationConfiguration *check.Config) error { if manifestPath != "" { manifest, err := os.ReadFile(manifestPath) diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index fd3ed29e4..863c96546 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -11,6 +11,7 @@ import ( "fmt" "math/big" "os" + "strings" "testing" "time" @@ -31,14 +32,15 @@ func TestNewClient(t *testing.T) { }) tests := []struct { - name string - cfg Config - wantErr bool - err error + name string + cfg BaseConfig + agentCfg AgentClientConfig + wantErr bool + err error }{ { name: "Success without TLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", }, wantErr: false, @@ -46,7 +48,7 @@ func TestNewClient(t *testing.T) { }, { name: "Success with TLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, }, @@ -55,7 +57,7 @@ func TestNewClient(t *testing.T) { }, { name: "Success with mTLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: clientCertFile, @@ -64,9 +66,52 @@ func TestNewClient(t *testing.T) { wantErr: false, err: nil, }, + { + name: "Success agent client with mTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + }, + wantErr: false, + err: nil, + }, + { + name: "Success agent client with aTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: "../../../scripts/attestation_policy/attestation_policy.json", + }, + wantErr: false, + err: nil, + }, + { + name: "Failed agent client with aTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: "no such file", + }, + wantErr: true, + err: fmt.Errorf("failed to read Attestation Policy"), + }, { name: "Fail with invalid ServerCAFile", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: "nonexistent.pem", }, @@ -75,7 +120,7 @@ func TestNewClient(t *testing.T) { }, { name: "Fail with invalid ClientCert", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: "nonexistent.pem", @@ -86,7 +131,7 @@ func TestNewClient(t *testing.T) { }, { name: "Fail with invalid ClientKey", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: clientCertFile, @@ -99,7 +144,12 @@ func TestNewClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client, err := NewClient(tt.cfg) + var client Client + if strings.Contains(tt.name, "agent client") { + client, err = NewClient(tt.agentCfg) + } else { + client, err = NewClient(tt.cfg) + } assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) if tt.wantErr { assert.Error(t, err) diff --git a/pkg/clients/grpc/manager/manager.go b/pkg/clients/grpc/manager/manager.go index 3f9373876..736796bc8 100644 --- a/pkg/clients/grpc/manager/manager.go +++ b/pkg/clients/grpc/manager/manager.go @@ -8,7 +8,7 @@ import ( ) // NewManagerClient creates new manager gRPC client instance. -func NewManagerClient(cfg grpc.Config) (grpc.Client, manager.ManagerServiceClient, error) { +func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) { client, err := grpc.NewClient(cfg) if err != nil { return nil, nil, err diff --git a/pkg/clients/grpc/manager/manager_test.go b/pkg/clients/grpc/manager/manager_test.go index 2ee2592f4..49ae5627f 100644 --- a/pkg/clients/grpc/manager/manager_test.go +++ b/pkg/clients/grpc/manager/manager_test.go @@ -13,21 +13,18 @@ import ( func TestNewManagerClient(t *testing.T) { tests := []struct { name string - cfg grpc.Config + cfg grpc.ManagerClientConfig err error }{ { name: "Valid config", - cfg: grpc.Config{ - URL: "localhost:7001", + cfg: grpc.ManagerClientConfig{ + BaseConfig: grpc.BaseConfig{ + URL: "localhost:7001", + }, }, err: nil, }, - { - name: "invalid config, missing AttestationPolicy with aTLS", - cfg: grpc.Config{AttestedTLS: true}, - err: grpc.ErrAttestationPolicyMissing, - }, } for _, tt := range tests { diff --git a/test/computations/main.go b/test/computations/main.go index b2e3dda2b..a45fa7eeb 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -130,7 +130,11 @@ func main() { reflection.Register(srv) manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(incomingChan, &svc{logger: logger})) } - grpcServerConfig := server.Config{Port: defaultPort} + grpcServerConfig := server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Port: defaultPort, + }, + } if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) return