diff --git a/cmd/join.go b/cmd/join.go index 89887c1b..d0482f4a 100644 --- a/cmd/join.go +++ b/cmd/join.go @@ -57,6 +57,8 @@ func init() { joinCmd.Flags().StringP(registerFlags.Name, "o", "", "sets host name") joinCmd.Flags().StringP(registerFlags.Interface, "I", "", "sets netmaker interface to use on host") joinCmd.Flags().StringP(registerFlags.Firewall, "f", "", "selects firewall to use on host: iptables/nftables") + joinCmd.Flags().BoolP(registerFlags.ForcePrivate, "P", false, "forces Windows network interface to be classified as Private") + joinCmd.Flags().StringP(registerFlags.ProfileName, "N", "", "sets Windows network profile name for the interface") rootCmd.AddCommand(joinCmd) } diff --git a/cmd/register.go b/cmd/register.go index 29879b0d..5425afe0 100644 --- a/cmd/register.go +++ b/cmd/register.go @@ -21,34 +21,38 @@ import ( ) var registerFlags = struct { - Firewall string - Server string - User string - Token string - Network string - AllNetworks string - EndpointIP string - EndpointIP6 string - Port string - MTU string - StaticPort string - Static string - Interface string - Name string + Firewall string + Server string + User string + Token string + Network string + AllNetworks string + EndpointIP string + EndpointIP6 string + Port string + MTU string + StaticPort string + Static string + Interface string + Name string + ForcePrivate string + ProfileName string }{ - Server: "server", - User: "user", - Token: "token", - Network: "net", - AllNetworks: "all-networks", - EndpointIP: "endpoint-ip", - Port: "port", - MTU: "mtu", - StaticPort: "static-port", - Static: "static-endpoint", - Name: "name", - Interface: "interface", - Firewall: "firewall", + Server: "server", + User: "user", + Token: "token", + Network: "net", + AllNetworks: "all-networks", + EndpointIP: "endpoint-ip", + Port: "port", + MTU: "mtu", + StaticPort: "static-port", + Static: "static-endpoint", + Name: "name", + Interface: "interface", + Firewall: "firewall", + ForcePrivate: "force-private", + ProfileName: "profile-name", } // registerCmd represents the register command @@ -147,6 +151,22 @@ func setHostFields(cmd *cobra.Command) { config.Netclient().FirewallInUse = firewall } } + if forcePrivate, err := cmd.Flags().GetBool(registerFlags.ForcePrivate); err == nil { + if ncutils.IsWindows() { + config.Netclient().ForcePrivateProfile = forcePrivate + } + } + if profileName, err := cmd.Flags().GetString(registerFlags.ProfileName); err == nil && profileName != "" { + if ncutils.IsWindows() { + config.Netclient().InterfaceProfileName = profileName + } + } + // Save config if any Windows-specific settings were changed + if ncutils.IsWindows() && (cmd.Flags().Changed(registerFlags.ForcePrivate) || cmd.Flags().Changed(registerFlags.ProfileName)) { + if err := config.WriteNetclientConfig(); err != nil { + logger.Log(0, "failed to save config after setting Windows interface settings", err.Error()) + } + } } func validateIface(iface string) bool { if iface == "" { diff --git a/config/config.go b/config/config.go index 50c5ef98..5e93072f 100644 --- a/config/config.go +++ b/config/config.go @@ -98,6 +98,10 @@ type Config struct { NameServers []string `json:"name_servers" yaml:"name_servers"` DNSSearch string `json:"dns_search" yaml:"dns_search"` DNSOptions string `json:"dns_options" yaml:"dns_options"` + // ForcePrivateProfile - force Windows network interface to be classified as Private + ForcePrivateProfile bool `json:"force_private_profile" yaml:"force_private_profile"` + // InterfaceProfileName - Windows network profile name for the interface + InterfaceProfileName string `json:"interface_profile_name" yaml:"interface_profile_name"` } func init() { diff --git a/wireguard/wireguard_windows.go b/wireguard/wireguard_windows.go index f95416d7..f83fa8fc 100644 --- a/wireguard/wireguard_windows.go +++ b/wireguard/wireguard_windows.go @@ -5,24 +5,197 @@ import ( "fmt" "net" "net/netip" + "os/exec" "strconv" "strings" + "sync" + "time" "github.com/gravitl/netclient/config" "github.com/gravitl/netclient/ncutils" "github.com/gravitl/netmaker/logger" "golang.org/x/exp/slog" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/driver" ) // TODO: update from netsh to a more programmatic approach. +// SetInterfaceProfileName - sets the Windows network profile name for the interface +func SetInterfaceProfileName(ifaceName string, profileName string) error { + if profileName == "" { + return nil + } + + // Wait a bit for Windows to create the network profile after interface creation + // Retry up to 8 times with shorter delays (starts with 500ms, then 1s) + maxRetries := 8 + var lastErr error + for i := 0; i < maxRetries; i++ { + if i > 0 { + // Use shorter delay on first retry, then standard 1 second + delay := 1 * time.Second + if i == 1 { + delay = 500 * time.Millisecond + } + time.Sleep(delay) + } + + // Use registry method (most reliable for profile name) + err := setInterfaceProfileNameViaRegistry(ifaceName, profileName) + if err == nil { + slog.Info("set interface profile name via registry", "interface", ifaceName, "profileName", profileName) + return nil + } + lastErr = err + } + + return fmt.Errorf("failed to set profile name after %d attempts: %w", maxRetries, lastErr) +} + +// setInterfaceProfileNameViaRegistry - sets profile name by enumerating registry profiles +func setInterfaceProfileNameViaRegistry(ifaceName string, profileName string) error { + // Retry finding and updating the profile with delays + maxRetries := 5 + var err error + + for i := 0; i < maxRetries; i++ { + if i > 0 { + // Use shorter delay on first retry + delay := 1 * time.Second + if i == 1 { + delay = 500 * time.Millisecond + } + time.Sleep(delay) + } + + // Enumerate registry to find and update the profile + err = findAndUpdateProfileName(ifaceName, profileName) + if err == nil { + slog.Info("set interface profile name via registry", "interface", ifaceName, "profileName", profileName) + return nil + } + slog.Debug("failed to find/update profile in registry, retrying", "attempt", i+1, "error", err) + } + + return fmt.Errorf("failed to set profile name after %d attempts: %w", maxRetries, err) +} + +// findAndUpdateProfileName - enumerates registry profiles to find one matching interface name and update it +func findAndUpdateProfileName(ifaceName string, profileName string) error { + parentPath := `SOFTWARE\Microsoft\Windows NT\CurrentVersion\NetworkList\Profiles` + // Need READ permission (includes QUERY_VALUE and ENUMERATE_SUB_KEYS) to use Stat() method + parentKey, err := registry.OpenKey(registry.LOCAL_MACHINE, parentPath, registry.ALL_ACCESS) + if err != nil { + return fmt.Errorf("failed to open profiles registry key: %w", err) + } + defer parentKey.Close() + + // Get subkey count using Stat() for efficient enumeration + keyInfo, err := parentKey.Stat() + if err != nil { + return fmt.Errorf("failed to get registry key info: %w", err) + } + + // Read all subkeys at once since we know the count + subKeyNames, err := parentKey.ReadSubKeyNames(int(keyInfo.SubKeyCount)) + if err != nil { + return fmt.Errorf("failed to read subkeys: %w", err) + } + + // First, check if the profile name is already set to the target name + // This handles the case where a previous attempt succeeded but we're retrying + var matchingGUID string + var alreadySet bool + for _, guid := range subKeyNames { + subKey, err := registry.OpenKey(parentKey, guid, registry.QUERY_VALUE) + if err != nil { + continue + } + + currentProfileName, _, err := subKey.GetStringValue("ProfileName") + if err != nil { + subKey.Close() + continue + } + + // If profile name already matches target, we're done + if currentProfileName == profileName { + alreadySet = true + subKey.Close() + break + } + + // Check if ProfileName matches the interface name (for initial update) + if currentProfileName == ifaceName { + // Found a profile matching the interface name + matchingGUID = strings.Trim(guid, "{}") + } + subKey.Close() + } + + // If already set to target name, return success + if alreadySet { + return nil + } + + // If no matching profile found, return error + if matchingGUID == "" { + return fmt.Errorf("no profile found with ProfileName matching interface %s", ifaceName) + } + + // Update the profile name + profilePath := parentPath + `\` + "{" + matchingGUID + "}" + profileKey, err := registry.OpenKey(registry.LOCAL_MACHINE, profilePath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("failed to open profile key: %w", err) + } + defer profileKey.Close() + + err = profileKey.SetStringValue("ProfileName", profileName) + if err != nil { + return fmt.Errorf("failed to set profile name: %w", err) + } + + // Also set Description to match + err = profileKey.SetStringValue("Description", profileName) + if err != nil { + slog.Debug("failed to set Description, continuing", "error", err) + } + + slog.Debug("updated profile name in registry", "interface", ifaceName, "profileGUID", matchingGUID, "profileName", profileName) + return nil +} + +// SetInterfacePrivateProfile - sets the Windows network interface profile to Private +func SetInterfacePrivateProfile(ifaceName string) error { + // Use PowerShell to set the network profile to Private + psCmd := fmt.Sprintf("Set-NetConnectionProfile -InterfaceAlias '%s' -NetworkCategory Private -ErrorAction SilentlyContinue", ifaceName) + cmd := exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", psCmd) + output, err := cmd.Output() + if err != nil { + // Try alternative approach if the first one fails + psCmd2 := fmt.Sprintf("$profile = Get-NetConnectionProfile -InterfaceAlias '%s' -ErrorAction SilentlyContinue; if ($profile) { Set-NetConnectionProfile -InterfaceAlias '%s' -NetworkCategory Private -ErrorAction Stop }", ifaceName, ifaceName) + cmd2 := exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", psCmd2) + output, err = cmd2.Output() + if err != nil { + slog.Error("failed to set interface profile to Private", "interface", ifaceName, "error", err, "output", string(output)) + return fmt.Errorf("failed to set interface profile: %w", err) + } + } + slog.Info("set interface profile to Private", "interface", ifaceName) + return nil +} + // NCIface.Create - makes a new Wireguard interface and sets given addresses func (nc *NCIface) Create() error { wgMutex.Lock() defer wgMutex.Unlock() + // Flush network caches before creating interface to ensure clean state + FlushWindowsNetworkCaches() + adapter, err := driver.OpenAdapter(ncutils.GetInterfaceName()) if err != nil { slog.Info("creating Windows tunnel") @@ -37,8 +210,22 @@ func (nc *NCIface) Create() error { } adapter, err = driver.CreateAdapter(ncutils.GetInterfaceName(), "WireGuard", &windowsGUID) if err != nil { - slog.Error("creating adapter error: ", "error", err) - return err + // Check if adapter already exists - try to open it again + if strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "Cannot create a file when that file already exists") { + slog.Info("adapter already exists, attempting to open it") + // Retry opening the adapter - it might have been created by another process + var openErr error + adapter, openErr = driver.OpenAdapter(ncutils.GetInterfaceName()) + if openErr != nil { + slog.Error("creating adapter error (adapter exists but cannot be opened): ", "error", err, "openError", openErr) + return fmt.Errorf("adapter exists but cannot be opened: %w (original error: %v)", openErr, err) + } + slog.Info("successfully opened existing adapter") + err = nil // Clear the error since we successfully opened the adapter + } else { + slog.Error("creating adapter error: ", "error", err) + return err + } } } else { slog.Info("re-using existing adapter") @@ -46,7 +233,34 @@ func (nc *NCIface) Create() error { slog.Info("created Windows tunnel") nc.Iface = adapter - return adapter.SetAdapterState(driver.AdapterStateUp) + if err := adapter.SetAdapterState(driver.AdapterStateUp); err != nil { + return err + } + + // Set network profile settings asynchronously (non-blocking) + go func() { + ifaceName := ncutils.GetInterfaceName() + + // Wait for Windows to create and register the network profile in the registry + // Use a shorter initial wait, then check if interface exists + time.Sleep(2 * time.Second) + + // Set network profile to Private if force flag is set + if config.Netclient().ForcePrivateProfile { + if err := SetInterfacePrivateProfile(ifaceName); err != nil { + slog.Warn("failed to set interface profile to Private", "error", err) + } + } + + // Set interface profile name if configured + if config.Netclient().InterfaceProfileName != "" { + if err := SetInterfaceProfileName(ifaceName, config.Netclient().InterfaceProfileName); err != nil { + slog.Warn("failed to set interface profile name", "error", err) + } + } + }() + + return nil } // NCIface.ApplyAddrs - applies addresses to windows tunnel ifaces, unused currently @@ -495,6 +709,53 @@ func restoreInternetGwV4() (err error) { return config.WriteNetclientConfig() } +// FlushWindowsNetworkCaches - flushes Windows network caches (route, ARP, NetBIOS, DNS) +func FlushWindowsNetworkCaches() { + var wg sync.WaitGroup + + // Clear route destination cache + wg.Add(1) + go func() { + defer wg.Done() + _, _ = ncutils.RunCmd("netsh interface ip delete destinationcache", false) + }() + + // Clear ARP cache + wg.Add(1) + go func() { + defer wg.Done() + _, _ = ncutils.RunCmd("arp -d *", false) + }() + + // Clear NetBIOS cache + wg.Add(1) + go func() { + defer wg.Done() + _, _ = ncutils.RunCmd("nbtstat -R", false) + }() + + wg.Add(1) + go func() { + defer wg.Done() + _, _ = ncutils.RunCmd("nbtstat -RR", false) + }() + + // Flush DNS cache (run sequentially with registerdns to avoid potential conflicts) + wg.Add(1) + go func() { + defer wg.Done() + // Flush DNS cache first + _, _ = ncutils.RunCmd("ipconfig /flushdns", false) + // Then re-register DNS + _, _ = ncutils.RunCmd("ipconfig /registerdns", false) + }() + + // Wait for all commands to complete + wg.Wait() + + slog.Info("flushed Windows network caches") +} + // NCIface.Close - closes the managed WireGuard interface func (nc *NCIface) Close() { wgMutex.Lock() @@ -520,6 +781,9 @@ func (nc *NCIface) Close() { } } } + + // Flush network caches when closing interface + FlushWindowsNetworkCaches() } // NCIface.SetMTU - sets the MTU of the windows WireGuard Iface adapter