diff --git a/cns/NetworkContainerContract.go b/cns/NetworkContainerContract.go index 8f5939c28e..0a9e9cb2e6 100644 --- a/cns/NetworkContainerContract.go +++ b/cns/NetworkContainerContract.go @@ -7,12 +7,13 @@ import ( "strconv" "strings" - "github.com/Azure/azure-container-networking/cns/types" - "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" - "github.com/Azure/azure-container-networking/network/policy" "github.com/google/uuid" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" + + "github.com/Azure/azure-container-networking/cns/types" + "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" + "github.com/Azure/azure-container-networking/network/policy" ) // Container Network Service DNC Contract @@ -128,6 +129,7 @@ type CreateNetworkContainerRequest struct { AllowNCToHostCommunication bool EndpointPolicies []NetworkContainerRequestPolicies NCStatus v1alpha.NCStatus + SwiftV2PrefixOnNic bool // Indicates if is swiftv2 nc, PrefixOnNic scenario (isSwiftV2 && nc.Type == VNETBlock) NetworkInterfaceInfo NetworkInterfaceInfo //nolint // introducing new field for backendnic, to be used later by cni code } diff --git a/cns/kubecontroller/nodenetworkconfig/conversion.go b/cns/kubecontroller/nodenetworkconfig/conversion.go index f2f3d9e9cf..38cf9bb067 100644 --- a/cns/kubecontroller/nodenetworkconfig/conversion.go +++ b/cns/kubecontroller/nodenetworkconfig/conversion.go @@ -6,9 +6,10 @@ import ( "strconv" "strings" + "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" - "github.com/pkg/errors" ) var ( @@ -62,6 +63,7 @@ func CreateNCRequestFromDynamicNC(nc v1alpha.NetworkContainer) (*cns.CreateNetwo NetworkContainerid: nc.ID, NetworkContainerType: cns.Docker, Version: strconv.FormatInt(nc.Version, 10), //nolint:gomnd // it's decimal + SwiftV2PrefixOnNic: false, // Dynamic NCs don't use SwiftV2 PrefixOnNic IPConfiguration: cns.IPConfiguration{ IPSubnet: subnet, GatewayIPAddress: nc.DefaultGateway, diff --git a/cns/kubecontroller/nodenetworkconfig/conversion_linux.go b/cns/kubecontroller/nodenetworkconfig/conversion_linux.go index 9d425aa48f..c8e805c185 100644 --- a/cns/kubecontroller/nodenetworkconfig/conversion_linux.go +++ b/cns/kubecontroller/nodenetworkconfig/conversion_linux.go @@ -60,6 +60,7 @@ func createNCRequestFromStaticNCHelper(nc v1alpha.NetworkContainer, primaryIPPre GatewayIPv6Address: nc.DefaultGatewayV6, }, NCStatus: nc.Status, + SwiftV2PrefixOnNic: isSwiftV2 && nc.Type == v1alpha.VNETBlock, NetworkInterfaceInfo: cns.NetworkInterfaceInfo{ MACAddress: nc.MacAddress, }, diff --git a/cns/restserver/internalapi.go b/cns/restserver/internalapi.go index efefb3f2d3..0fbe299a12 100644 --- a/cns/restserver/internalapi.go +++ b/cns/restserver/internalapi.go @@ -12,17 +12,19 @@ import ( "net/http" "net/http/httptest" "reflect" + "runtime" "strconv" "strings" "time" + "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/nodesubnet" "github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/common" "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" - "github.com/pkg/errors" ) const ( @@ -228,8 +230,8 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo return len(programmedNCs), errors.Wrap(err, "failed to get nc version list from nmagent") } - // Get IMDS NC versions for delegated NIC scenarios - imdsNCVersions, err := service.GetIMDSNCs(ctx) + // Get IMDS NC versions for delegated NIC scenarios. If any of the NMA API check calls, imds calls fails assume that nma build doesn't have the latest changes and create empty map + imdsNCVersions := service.getIMDSNCs(ctx) if err != nil { // If any of the NMA API check calls, imds calls fails assume that nma build doesn't have the latest changes and create empty map imdsNCVersions = make(map[string]string) @@ -685,18 +687,18 @@ func (service *HTTPRestService) isNCDetailsAPIExists(ctx context.Context) bool { } // GetIMDSNCs gets NC versions from IMDS and returns them as a map -func (service *HTTPRestService) GetIMDSNCs(ctx context.Context) (map[string]string, error) { +func (service *HTTPRestService) getIMDSNCs(ctx context.Context) map[string]string { imdsClient := service.imdsClient if imdsClient == nil { //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated logger.Errorf("IMDS client is not available") - return make(map[string]string), nil + return make(map[string]string) } // Check NC version support if !service.isNCDetailsAPIExists(ctx) { //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated logger.Errorf("IMDS does not support NC details API") - return make(map[string]string), nil + return make(map[string]string) } // Get all network interfaces from IMDS @@ -704,7 +706,7 @@ func (service *HTTPRestService) GetIMDSNCs(ctx context.Context) (map[string]stri if err != nil { //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated logger.Errorf("Failed to get network interfaces from IMDS: %v", err) - return make(map[string]string), nil + return make(map[string]string) } // Build ncs map from the network interfaces @@ -717,8 +719,26 @@ func (service *HTTPRestService) GetIMDSNCs(ctx context.Context) (map[string]stri if ncID != "" { ncs[ncID] = PrefixOnNicNCVersion // for prefix on nic version scenario nc version is 1 + } else if runtime.GOOS == "windows" && service.isPrefixonNicSwiftV2() { + err := service.setPrefixOnNICRegistry(true, iface.MacAddress.String()) + if err != nil { + //nolint:staticcheck // SA1019: suppress deprecated logger.Debugf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Debugf("failed to add PrefixOnNic keys to Windows registry: %w", err) + } } } - return ncs, nil + return ncs +} + +// Check whether NC is SwiftV2 NIC associated NC and prefix on nic is enabled +func (service *HTTPRestService) isPrefixonNicSwiftV2() bool { + for i := range service.state.ContainerStatus { + req := service.state.ContainerStatus[i].CreateNetworkContainerRequest + + if req.SwiftV2PrefixOnNic { + return true + } + } + return false } diff --git a/cns/restserver/internalapi_linux.go b/cns/restserver/internalapi_linux.go index 0abae0a72a..0020ada454 100644 --- a/cns/restserver/internalapi_linux.go +++ b/cns/restserver/internalapi_linux.go @@ -6,13 +6,14 @@ import ( "os/exec" "strconv" + goiptables "github.com/coreos/go-iptables/iptables" + "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/network/networkutils" - goiptables "github.com/coreos/go-iptables/iptables" - "github.com/pkg/errors" ) const SWIFTPOSTROUTING = "SWIFT-POSTROUTING" @@ -181,3 +182,11 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer func (service *HTTPRestService) setVFForAccelnetNICs() error { return nil } + +func (service *HTTPRestService) setPrefixOnNICRegistry(enabled bool, infraNicMacAddress string) error { + // Assigning parameters to '_' to avoid unused parameter linting errors. + // These parameters are only used in the Windows implementation. + _ = enabled + _ = infraNicMacAddress + return nil +} diff --git a/cns/restserver/internalapi_test.go b/cns/restserver/internalapi_test.go index 4df797a498..ec9295e358 100644 --- a/cns/restserver/internalapi_test.go +++ b/cns/restserver/internalapi_test.go @@ -10,12 +10,19 @@ import ( "net" "os" "reflect" + "runtime" "strconv" "strings" "sync" "testing" "time" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/configuration" @@ -25,11 +32,6 @@ import ( "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" nma "github.com/Azure/azure-container-networking/nmagent" "github.com/Azure/azure-container-networking/store" - "github.com/google/uuid" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" ) const ( @@ -1680,3 +1682,128 @@ func setupIMDSMockAPIsWithCustomIDs(svc *HTTPRestService, interfaceIDs []string) // Return cleanup function return func() { svc.imdsClient = originalIMDS } } + +// TestSyncHostNCVersionWithWindowsSwiftV2 tests SyncHostNCVersion and verifies it calls Windows SwiftV2 PrefixOnNic scenario +func TestSyncHostNCVersionWithWindowsSwiftV2(t *testing.T) { + testSvc := getTestService(cns.Kubernetes) + + // Set up test NCs with different scenarios + regularNCID := "regular-nc-id" + swiftV2NCID := "swift-v2-vnet-block-nc" + + // Initialize ContainerStatus map if nil + if testSvc.state.ContainerStatus == nil { + testSvc.state.ContainerStatus = make(map[string]containerstatus) + } + + // Add a regular NC + testSvc.state.ContainerStatus[regularNCID] = containerstatus{ + ID: regularNCID, + CreateNetworkContainerRequest: cns.CreateNetworkContainerRequest{ + NetworkContainerid: regularNCID, + SwiftV2PrefixOnNic: false, + NetworkContainerType: cns.Docker, + Version: "2", + }, + HostVersion: "1", + } + + // Add a SwiftV2 VNETBlock NC that should trigger Windows registry operations + testSvc.state.ContainerStatus[swiftV2NCID] = containerstatus{ + ID: swiftV2NCID, + CreateNetworkContainerRequest: cns.CreateNetworkContainerRequest{ + NetworkContainerid: swiftV2NCID, + SwiftV2PrefixOnNic: true, + NetworkContainerType: cns.Docker, + Version: "2", + }, + HostVersion: "1", + } + + // Set up mock NMAgent with NC versions + mockNMA := &fakes.NMAgentClientFake{} + mockNMA.GetNCVersionListF = func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{ + Containers: []nma.NCVersion{ + { + NetworkContainerID: regularNCID, + Version: "2", + }, + { + NetworkContainerID: swiftV2NCID, + Version: "2", + }, + }, + }, nil + } + testSvc.nma = mockNMA + + // Set up mock IMDS client for Windows SwiftV2 scenario + mac1, _ := net.ParseMAC("AA:BB:CC:DD:EE:FF") + mac2, _ := net.ParseMAC("11:22:33:44:55:66") + + interfaceMap := map[string]imds.NetworkInterface{ + "interface1": { + InterfaceCompartmentID: "", // Empty for Windows condition + MacAddress: imds.HardwareAddr(mac1), + }, + "interface2": { + InterfaceCompartmentID: "nc-with-compartment-id", + MacAddress: imds.HardwareAddr(mac2), + }, + } + mockIMDS := &mockIMDSAdapter{ + mock: &struct { + networkInterfaces func(_ context.Context) ([]imds.NetworkInterface, error) + imdsVersions func(_ context.Context) (*imds.APIVersionsResponse, error) + }{ + networkInterfaces: func(_ context.Context) ([]imds.NetworkInterface, error) { + var interfaces []imds.NetworkInterface + for _, iface := range interfaceMap { + interfaces = append(interfaces, iface) + } + return interfaces, nil + }, + imdsVersions: func(_ context.Context) (*imds.APIVersionsResponse, error) { + return &imds.APIVersionsResponse{ + APIVersions: []string{expectedIMDSAPIVersion}, + }, nil + }, + }, + } + + // Replace the IMDS client + originalIMDS := testSvc.imdsClient + testSvc.imdsClient = mockIMDS + defer func() { testSvc.imdsClient = originalIMDS }() + + // Verify preconditions + assert.True(t, testSvc.isPrefixonNicSwiftV2(), "isPrefixonNicSwiftV2() should return true") + + ctx := context.Background() + testSvc.SyncHostNCVersion(ctx, cns.CRD) + + // Verify that NC versions were updated + updatedRegularNC := testSvc.state.ContainerStatus[regularNCID] + updatedSwiftV2NC := testSvc.state.ContainerStatus[swiftV2NCID] + + assert.Equal(t, "2", updatedRegularNC.HostVersion, "Regular NC host version should be updated to 2") + assert.Equal(t, "2", updatedSwiftV2NC.HostVersion, "SwiftV2 NC host version should be updated to 2") + + imdsNCs := testSvc.getIMDSNCs(ctx) + + // Verify IMDS results + assert.Contains(t, imdsNCs, "nc-with-compartment-id", "NC with compartment ID should be in results") + assert.Equal(t, PrefixOnNicNCVersion, imdsNCs["nc-with-compartment-id"], "NC should have expected version") + + // Log the conditions that would trigger Windows registry operations + isWindows := runtime.GOOS == "windows" + hasSwiftV2PrefixOnNic := testSvc.isPrefixonNicSwiftV2() + + t.Logf("Windows SwiftV2 PrefixOnNic conditions: (runtime.GOOS == 'windows' && service.isPrefixonNicSwiftV2()): %t", + isWindows && hasSwiftV2PrefixOnNic) + + // Test with no SwiftV2 NCs + delete(testSvc.state.ContainerStatus, swiftV2NCID) + assert.False(t, testSvc.isPrefixonNicSwiftV2(), "isPrefixonNicSwiftV2() should return false without SwiftV2 NCs") +} diff --git a/cns/restserver/internalapi_windows.go b/cns/restserver/internalapi_windows.go index 245053c8d6..446be42d6e 100644 --- a/cns/restserver/internalapi_windows.go +++ b/cns/restserver/internalapi_windows.go @@ -5,15 +5,22 @@ import ( "fmt" "time" - "github.com/Azure/azure-container-networking/cns" - "github.com/Azure/azure-container-networking/cns/types" "github.com/Microsoft/hcsshim" "github.com/pkg/errors" + "golang.org/x/sys/windows/registry" + + "github.com/Azure/azure-container-networking/cns" + "github.com/Azure/azure-container-networking/cns/logger" + "github.com/Azure/azure-container-networking/cns/types" ) const ( // timeout for powershell command to return the interfaces list - pwshTimeout = 120 * time.Second + pwshTimeout = 120 * time.Second + hnsRegistryPath = `SYSTEM\CurrentControlSet\Services\HNS\wcna_state\config` + prefixOnNicRegistryPath = `SYSTEM\CurrentControlSet\Services\HNS\wcna_state\config\PrefixOnNic` + infraNicIfName = "eth0" + enableSNAT = false ) var errUnsupportedAPI = errors.New("unsupported api") @@ -75,3 +82,74 @@ func (service *HTTPRestService) getPrimaryNICMACAddress() (string, error) { } return macAddress, nil } + +func (service *HTTPRestService) enablePrefixOnNic(enabled bool) error { + return service.setRegistryValue(prefixOnNicRegistryPath, "enabled", enabled) +} + +func (service *HTTPRestService) setInfraNicMacAddress(macAddress string) error { + return service.setRegistryValue(prefixOnNicRegistryPath, "infra_nic_mac_address", macAddress) +} + +func (service *HTTPRestService) setInfraNicIfName(ifName string) error { + return service.setRegistryValue(prefixOnNicRegistryPath, "infra_nic_ifname", ifName) +} + +func (service *HTTPRestService) setEnableSNAT(enabled bool) error { + return service.setRegistryValue(hnsRegistryPath, "EnableSNAT", enabled) +} + +func (service *HTTPRestService) setPrefixOnNICRegistry(enablePrefixOnNic bool, infraNicMacAddress string) error { + if err := service.enablePrefixOnNic(enablePrefixOnNic); err != nil { + return fmt.Errorf("failed to set enablePrefixOnNic key to windows registry: %w", err) + } + + if err := service.setInfraNicMacAddress(infraNicMacAddress); err != nil { + return fmt.Errorf("failed to set InfraNicMacAddress key to windows registry: %w", err) + } + + if err := service.setInfraNicIfName(infraNicIfName); err != nil { + return fmt.Errorf("failed to set InfraNicIfName key to windows registry: %w", err) + } + + if err := service.setEnableSNAT(enableSNAT); err != nil { + return fmt.Errorf("failed to set EnableSNAT key to windows registry: %w", err) + } + + return nil +} + +func (service *HTTPRestService) setRegistryValue(registryPath, keyName string, value interface{}) error { + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, registryPath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("failed to create/open registry key %s: %w", registryPath, err) + } + defer key.Close() + + switch v := value.(type) { + case string: + err = key.SetStringValue(keyName, v) + case bool: + dwordValue := uint32(0) + if v { + dwordValue = 1 + } + err = key.SetDWordValue(keyName, dwordValue) + case uint32: + err = key.SetDWordValue(keyName, v) + case int: + case int: + if v < 0 || v > int(^uint32(0)) { + return fmt.Errorf("int value %d overflows uint32 for registry key %s", v, keyName) + } + err = key.SetDWordValue(keyName, uint32(v)) + default: + return fmt.Errorf("unsupported value type for registry key %s: %T", keyName, value) + } + if err != nil { + return fmt.Errorf("failed to set registry value '%s': %w", keyName, err) + } + + logger.Printf("[setRegistryValue] Set %s\\%s = %v", registryPath, keyName, value) + return nil +}