From d669d381c8ef5102e565b3ad6bc20fe35272b961 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 11:30:06 -0700 Subject: [PATCH 1/3] filtering TLS connections based on the subject name from Caller --- cns/configuration/cns_config.json | 3 +- cns/configuration/configuration.go | 1 + cns/configuration/configuration_test.go | 9 +- cns/service.go | 26 +++++- cns/service/main.go | 1 + cns/service_test.go | 110 +++++++++++++++--------- server/tls/tlscertificate_retriever.go | 1 + 7 files changed, 104 insertions(+), 47 deletions(-) diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 81ef6c9b05..1b17fdc4ab 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -35,5 +35,6 @@ "AZRSettings": { "PopulateHomeAzCacheRetryIntervalSecs": 60 }, - "MinTLSVersion": "TLS 1.2" + "MinTLSVersion": "TLS 1.2", + "AllowedClientSubjectName": "" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 9ec5f8664f..c7bbe0ee48 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -59,6 +59,7 @@ type CNSConfig struct { WireserverIP string GRPCSettings GRPCSettings MinTLSVersion string + AllowedClientSubjectName string } type TelemetrySettings struct { diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 186c92c376..7aa230d788 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "localhost", Port: 8080, }, - MinTLSVersion: "TLS 1.2", + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "", }, }, { @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, want: CNSConfig{ ChannelMode: "Other", @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, }, } diff --git a/cns/service.go b/cns/service.go index ab7a0be3c3..5e8403cc22 100644 --- a/cns/service.go +++ b/cns/service.go @@ -156,6 +156,25 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings) } +// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. +func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { + // no client subject name provided, skip verification + if clientSubjectName == "" { + return nil + } + + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return errors.Errorf("failed to parse certificate: %v", err) + } + + err = cert.VerifyHostname(clientSubjectName) + if err != nil { + return errors.Errorf("failed to verify client certificate hostname: %v", err) + } + return nil +} + func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) { tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) if err != nil { @@ -202,8 +221,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } - logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) return tlsConfig, nil @@ -254,6 +275,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings) diff --git a/cns/service/main.go b/cns/service/main.go index 67f7872f44..e4f207f6e0 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,6 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, + AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName, } } diff --git a/cns/service_test.go b/cns/service_test.go index d20c2ef11a..8da4333e00 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -12,6 +12,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "math/big" "net/http" "os" @@ -133,57 +134,82 @@ func TestNewService(t *testing.T) { t.Run("NewServiceWithMutualTLS", func(t *testing.T) { testCertFilePath := createTestCertificate(t) - config.TLSSettings = serverTLS.TlsSettings{ - TLSPort: "10091", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", + TLSSetting := serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "example.com", } - svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) - require.NoError(t, err) - require.IsType(t, &Service{}, svc) + TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{ + TLSPort: "10092", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "random.com", + } - svc.SetOption(acn.OptCnsURL, "") - svc.SetOption(acn.OptCnsPort, "") + runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) { + config.TLSSettings = tlsSettings + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) - err = svc.Initialize(config) - t.Cleanup(func() { - svc.Uninitialize() - }) - require.NoError(t, err) + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") - err = svc.StartListener(config) - require.NoError(t, err) + err = svc.Initialize(config) + require.NoError(t, err) - mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) - require.NoError(t, err) + err = svc.StartListener(config) + require.NoError(t, err) - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: mTLSConfig, - }, - } + mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) + require.NoError(t, err) - // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) - require.NoError(t, err) - resp, err := client.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: mTLSConfig, + }, + } - // HTTP listener - httpClient := &http.Client{} - req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) - require.NoError(t, err) - resp, err = httpClient.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + tlsUrl := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsUrl, http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }) + if handshakeFailureExpected { + require.Error(t, err) + require.ErrorContains(t, err, "failed to verify client certificate hostname") + + } else { + require.NoError(t, err) + } + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + + // Cleanup + svc.Uninitialize() + + } + runMutualTLSTest(TLSSetting, false) + runMutualTLSTest(TLSSettingWithDisallowedClientSN, true) }) } diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index a22a7336b7..dbd65af6ca 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -15,6 +15,7 @@ type TlsSettings struct { KeyVaultCertificateRefreshInterval time.Duration UseMTLS bool MinTLSVersion string + AllowedClientSubjectName string } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) { From fddda40063e1c1e0e95e6e430315aff2aa9df5b4 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 12:44:18 -0700 Subject: [PATCH 2/3] add validation to client rawCerts --- cns/service.go | 4 ++++ cns/service_test.go | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cns/service.go b/cns/service.go index 5e8403cc22..f28a8fa7b3 100644 --- a/cns/service.go +++ b/cns/service.go @@ -158,6 +158,10 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. // verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { + + if len(rawCerts) == 0 { + return errors.New("no client certificate provided") + } // no client subject name provided, skip verification if clientSubjectName == "" { return nil diff --git a/cns/service_test.go b/cns/service_test.go index 8da4333e00..58f1cd6b01 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -176,9 +176,9 @@ func TestNewService(t *testing.T) { }, } - tlsUrl := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + tlsURL := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsUrl, http.NoBody) + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) require.NoError(t, err) resp, err := client.Do(req) t.Cleanup(func() { From eb0bbdd0c46afb8bfd01922662ec5b2571ad0a43 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 14:28:31 -0700 Subject: [PATCH 3/3] fix lint --- cns/service.go | 1 - cns/service_test.go | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/cns/service.go b/cns/service.go index f28a8fa7b3..d879ef0c4b 100644 --- a/cns/service.go +++ b/cns/service.go @@ -158,7 +158,6 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. // verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { - if len(rawCerts) == 0 { return errors.New("no client certificate provided") } diff --git a/cns/service_test.go b/cns/service_test.go index 58f1cd6b01..0b8b362d2d 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -12,7 +12,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "math/big" "net/http" "os" @@ -176,7 +175,7 @@ func TestNewService(t *testing.T) { }, } - tlsURL := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + tlsURL := "https://localhost:" + tlsSettings.TLSPort // TLS listener req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) require.NoError(t, err) @@ -206,7 +205,6 @@ func TestNewService(t *testing.T) { // Cleanup svc.Uninitialize() - } runMutualTLSTest(TLSSetting, false) runMutualTLSTest(TLSSettingWithDisallowedClientSN, true)