diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 81ef6c9b05..1b17fdc4ab 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -35,5 +35,6 @@ "AZRSettings": { "PopulateHomeAzCacheRetryIntervalSecs": 60 }, - "MinTLSVersion": "TLS 1.2" + "MinTLSVersion": "TLS 1.2", + "AllowedClientSubjectName": "" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 9ec5f8664f..c7bbe0ee48 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -59,6 +59,7 @@ type CNSConfig struct { WireserverIP string GRPCSettings GRPCSettings MinTLSVersion string + AllowedClientSubjectName string } type TelemetrySettings struct { diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 186c92c376..7aa230d788 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "localhost", Port: 8080, }, - MinTLSVersion: "TLS 1.2", + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "", }, }, { @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, want: CNSConfig{ ChannelMode: "Other", @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, }, } diff --git a/cns/service.go b/cns/service.go index ab7a0be3c3..d879ef0c4b 100644 --- a/cns/service.go +++ b/cns/service.go @@ -156,6 +156,28 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings) } +// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. +func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { + if len(rawCerts) == 0 { + return errors.New("no client certificate provided") + } + // no client subject name provided, skip verification + if clientSubjectName == "" { + return nil + } + + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return errors.Errorf("failed to parse certificate: %v", err) + } + + err = cert.VerifyHostname(clientSubjectName) + if err != nil { + return errors.Errorf("failed to verify client certificate hostname: %v", err) + } + return nil +} + func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) { tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) if err != nil { @@ -202,8 +224,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } - logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) return tlsConfig, nil @@ -254,6 +278,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings) diff --git a/cns/service/main.go b/cns/service/main.go index 67f7872f44..e4f207f6e0 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,6 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, + AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName, } } diff --git a/cns/service_test.go b/cns/service_test.go index d20c2ef11a..0b8b362d2d 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -133,57 +133,81 @@ func TestNewService(t *testing.T) { t.Run("NewServiceWithMutualTLS", func(t *testing.T) { testCertFilePath := createTestCertificate(t) - config.TLSSettings = serverTLS.TlsSettings{ - TLSPort: "10091", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", + TLSSetting := serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "example.com", } - svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) - require.NoError(t, err) - require.IsType(t, &Service{}, svc) + TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{ + TLSPort: "10092", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "random.com", + } - svc.SetOption(acn.OptCnsURL, "") - svc.SetOption(acn.OptCnsPort, "") + runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) { + config.TLSSettings = tlsSettings + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) - err = svc.Initialize(config) - t.Cleanup(func() { - svc.Uninitialize() - }) - require.NoError(t, err) + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") - err = svc.StartListener(config) - require.NoError(t, err) + err = svc.Initialize(config) + require.NoError(t, err) - mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) - require.NoError(t, err) + err = svc.StartListener(config) + require.NoError(t, err) - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: mTLSConfig, - }, - } + mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) + require.NoError(t, err) - // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) - require.NoError(t, err) - resp, err := client.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: mTLSConfig, + }, + } - // HTTP listener - httpClient := &http.Client{} - req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) - require.NoError(t, err) - resp, err = httpClient.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + tlsURL := "https://localhost:" + tlsSettings.TLSPort + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }) + if handshakeFailureExpected { + require.Error(t, err) + require.ErrorContains(t, err, "failed to verify client certificate hostname") + + } else { + require.NoError(t, err) + } + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + + // Cleanup + svc.Uninitialize() + } + runMutualTLSTest(TLSSetting, false) + runMutualTLSTest(TLSSettingWithDisallowedClientSN, true) }) } diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index a22a7336b7..dbd65af6ca 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -15,6 +15,7 @@ type TlsSettings struct { KeyVaultCertificateRefreshInterval time.Duration UseMTLS bool MinTLSVersion string + AllowedClientSubjectName string } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {