diff --git a/cli/internal/config/nodebootstrap/nodebootstrap.go b/cli/internal/config/nodebootstrap/nodebootstrap.go index 753472d..9c31d4e 100644 --- a/cli/internal/config/nodebootstrap/nodebootstrap.go +++ b/cli/internal/config/nodebootstrap/nodebootstrap.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "io" + "strings" "github.com/spf13/cobra" + "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/nebius/instance" "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/userdata/flex" "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/userdata/ubuntu" "github.com/Azure/aks-flex/plugin/pkg/util/cloudinit" @@ -25,6 +27,8 @@ var flagEnableNvidiaGPURuntime bool var flagVariant string var flagArch string var flagKubeVersion string +var flagWireguardIP string +var flagWireguardSite string func init() { r.Handle("ubuntu", writeUbuntuUserData) @@ -37,6 +41,10 @@ func init() { "Kubernetes version for the downloaded binaries (e.g. 1.33.3).") Command.Flags().StringVar(&flagVariant, "variant", variantCloudInit, fmt.Sprintf("Output variant: %q produces cloud-init YAML user data, %q produces an equivalent standalone bash script.", variantCloudInit, variantScript)) + Command.Flags().StringVar(&flagWireguardIP, "wireguard-ip", "", + "WireGuard peer IP address for the node.") + Command.Flags().StringVar(&flagWireguardSite, "wireguard-site", "remote", + "WireGuard site for the node.") } // marshalUserData marshals the cloud-init UserData according to the selected @@ -62,15 +70,37 @@ func marshalUserData(ud *cloudinit.UserData, w io.Writer) error { } func writeFlexUserData(ctx context.Context, w io.Writer) error { + kubeadmConfig := configcmd.DefaultKubeadmConfig(ctx) + if flagWireguardIP != "" { + kubeadmConfig.SetNodeIp(flagWireguardIP) + } + ud, err := flex.UserData( flex.WithEnableNvidiaGPURuntime(flagEnableNvidiaGPURuntime), flex.WithArch(flagArch), flex.WithKubeVersion(flagKubeVersion), - flex.WithKubeadmConfig(configcmd.DefaultKubeadmConfig(ctx)), + flex.WithKubeadmConfig(kubeadmConfig), ) if err != nil { return fmt.Errorf("generating flex userdata: %w", err) } + + if flagWireguardIP != "" { + ud.Packages = append(ud.Packages, "wireguard", "wireguard-tools") + ud.WriteFiles = append(ud.WriteFiles, &cloudinit.WriteFile{ + Path: "/root/wg-spoke.sh", + Content: instance.GetWgSpokeScript(), + Permissions: "0755", + }) + ud.RunCmd = append(ud.RunCmd, strings.Join([]string{ + "export ANNOTATION_PREFIX='stretch.azure.com/wireguard-'", + fmt.Sprintf("export WG_ADDRESS='%s/32'", flagWireguardIP), + fmt.Sprintf("export WG_SITE='%s'", flagWireguardSite), + "export WG_DAEMONIZE='true'", + "/root/wg-spoke.sh", + }, "\n")) + } + return marshalUserData(ud, w) } diff --git a/plugin/pkg/services/agentpools/nebius/instance/agentpools.go b/plugin/pkg/services/agentpools/nebius/instance/agentpools.go index a517cc5..268834b 100644 --- a/plugin/pkg/services/agentpools/nebius/instance/agentpools.go +++ b/plugin/pkg/services/agentpools/nebius/instance/agentpools.go @@ -28,6 +28,10 @@ import ( //go:embed assets/wg-spoke.sh var wgSpokeScript string +func GetWgSpokeScript() string { + return wgSpokeScript +} + var _ api.Object = (*AgentPool)(nil) type agentPoolsServer struct {