Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cns/configuration/cns_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@
"AZRSettings": {
"PopulateHomeAzCacheRetryIntervalSecs": 60
},
"MinTLSVersion": "TLS 1.2"
"MinTLSVersion": "TLS 1.2",
"AllowedClientSubjectName": ""
}
1 change: 1 addition & 0 deletions cns/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type CNSConfig struct {
WireserverIP string
GRPCSettings GRPCSettings
MinTLSVersion string
AllowedClientSubjectName string
}

type TelemetrySettings struct {
Expand Down
9 changes: 6 additions & 3 deletions cns/configuration/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "localhost",
Port: 8080,
},
MinTLSVersion: "TLS 1.2",
MinTLSVersion: "TLS 1.2",
AllowedClientSubjectName: "",
},
},
{
Expand Down Expand Up @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
AllowedClientSubjectName: "example.com",
},
want: CNSConfig{
ChannelMode: "Other",
Expand Down Expand Up @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
AllowedClientSubjectName: "example.com",
},
},
}
Expand Down
29 changes: 28 additions & 1 deletion cns/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,28 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
}

// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name.
func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error {
if len(rawCerts) == 0 {
return errors.New("no client certificate provided")
}
// no client subject name provided, skip verification
if clientSubjectName == "" {
return nil
}

Copy link

Copilot AI Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No validation that rawCerts slice is not empty before accessing rawCerts[0]. This could cause a panic if an empty slice is passed.

Suggested change
if len(rawCerts) == 0 {
return errors.New("no client certificate provided")
}

Copilot uses AI. Check for mistakes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No validation that rawCerts slice is not empty

if len(rawCerts) == 0 {

???

cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errors.Errorf("failed to parse certificate: %v", err)
}

err = cert.VerifyHostname(clientSubjectName)
if err != nil {
return errors.Errorf("failed to verify client certificate hostname: %v", err)
}
return nil
Comment on lines +174 to +178
Copy link

Copilot AI Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using VerifyHostname for client certificate subject name verification is incorrect. VerifyHostname is designed for server certificate validation and checks hostname against DNS names and CN fields. For client certificate subject name validation, you should directly compare against the certificate's Subject.CommonName or implement proper subject name matching logic.

Suggested change
err = cert.VerifyHostname(clientSubjectName)
if err != nil {
return errors.Errorf("failed to verify client certificate hostname: %v", err)
}
return nil
if cert.Subject.CommonName != clientSubjectName {
return errors.Errorf("client certificate subject name mismatch: got %q, want %q", cert.Subject.CommonName, clientSubjectName)
}
return nil

Copilot uses AI. Check for mistakes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one could be right

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep this function is designed for server certificate validation actually. But its functionality is exactly what I want. It will validate the given strings with SANs /IPaddress of the certs.

}

func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
if err != nil {
Expand Down Expand Up @@ -202,8 +224,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName)
}
}

logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)

return tlsConfig, nil
Expand Down Expand Up @@ -254,6 +278,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName)
}
}

logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
Expand Down
1 change: 1 addition & 0 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ func main() {
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
UseMTLS: cnsconfig.UseMTLS,
MinTLSVersion: cnsconfig.MinTLSVersion,
AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName,
}
}

Expand Down
108 changes: 66 additions & 42 deletions cns/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,57 +133,81 @@ func TestNewService(t *testing.T) {
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
testCertFilePath := createTestCertificate(t)

config.TLSSettings = serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
TLSSetting := serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
AllowedClientSubjectName: "example.com",
}

svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{
TLSPort: "10092",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
AllowedClientSubjectName: "random.com",
}

svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) {
config.TLSSettings = tlsSettings
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)

err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")

err = svc.StartListener(config)
require.NoError(t, err)
err = svc.Initialize(config)
require.NoError(t, err)

mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)
err = svc.StartListener(config)
require.NoError(t, err)

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)

// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
tlsURL := "https://localhost:" + tlsSettings.TLSPort
// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
})
if handshakeFailureExpected {
require.Error(t, err)
require.ErrorContains(t, err, "failed to verify client certificate hostname")

} else {
require.NoError(t, err)
}

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)

// Cleanup
svc.Uninitialize()
}
runMutualTLSTest(TLSSetting, false)
runMutualTLSTest(TLSSettingWithDisallowedClientSN, true)
})
}

Expand Down
1 change: 1 addition & 0 deletions server/tls/tlscertificate_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type TlsSettings struct {
KeyVaultCertificateRefreshInterval time.Duration
UseMTLS bool
MinTLSVersion string
AllowedClientSubjectName string
}

func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {
Expand Down
Loading