diff --git a/cmd/main.go b/cmd/main.go index 3b75421f..216b3ff7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -109,7 +109,7 @@ func init() { } karpenterScheme.Register(&karpv1.NodeClaim{}, &karpv1.NodeClaimList{}) karpenterScheme.Register(&karpv1.NodePool{}, &karpv1.NodePoolList{}) - karpenterScheme.AddToScheme(scheme) + utilruntime.Must(karpenterScheme.AddToScheme(scheme)) } //nolint:gocyclo diff --git a/internal/utils/compose.go b/internal/utils/compose.go index 8802c6ce..cef36f89 100644 --- a/internal/utils/compose.go +++ b/internal/utils/compose.go @@ -78,11 +78,6 @@ type TensorFusionInfo struct { EnabledReplicas *int32 WorkloadName string ContainerNames []string - GenWorkload bool - - // Pod mutating webhook can not get Pod UID sometimes, - // thus need pod controller to set the owner reference - PendingSetPodAsOwner bool } func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo TensorFusionInfo) { diff --git a/internal/utils/owner_ref_utils.go b/internal/utils/owner_ref_utils.go index 5a4fe9df..97c8cf4f 100644 --- a/internal/utils/owner_ref_utils.go +++ b/internal/utils/owner_ref_utils.go @@ -96,3 +96,43 @@ func FindFirstLevelOwnerReference(obj metav1.Object) *metav1.OwnerReference { } return &ownerRef } + +// FindRootControllerRef recursively finds the root controller reference for a given object (e.g. Pod). 
+func FindRootControllerRef(ctx context.Context, c client.Client, obj metav1.Object) (*metav1.OwnerReference, error) { + if metav1.GetControllerOfNoCopy(obj) == nil { + return nil, nil + } + + namespace := obj.GetNamespace() + current := obj + for { + controllerRef := metav1.GetControllerOf(current) + if controllerRef == nil { + if rObj, ok := current.(runtime.Object); ok { + gvk := rObj.GetObjectKind().GroupVersionKind() + return metav1.NewControllerRef(current, gvk), nil + } else { + return nil, fmt.Errorf("not a runtime.Object") + } + } + + unObj := &unstructured.Unstructured{} + unObj.SetAPIVersion(controllerRef.APIVersion) + unObj.SetKind(controllerRef.Kind) + err := c.Get(ctx, client.ObjectKey{Name: controllerRef.Name, Namespace: namespace}, unObj) + if err != nil { + // if not found, return controllerRef as root + if errors.IsNotFound(err) { + return controllerRef, nil + } + return nil, fmt.Errorf("get controller object: %w", err) + } + + // Cast back to metav1.Object if possible + if metaObj, ok := any(unObj).(metav1.Object); ok { + current = metaObj + } else { + return nil, fmt.Errorf("unexpected type for controller object %s/%s", controllerRef.Kind, controllerRef.Name) + } + } +} diff --git a/internal/utils/owner_ref_utils_test.go b/internal/utils/owner_ref_utils_test.go index 5b77b531..16d8386b 100644 --- a/internal/utils/owner_ref_utils_test.go +++ b/internal/utils/owner_ref_utils_test.go @@ -140,3 +140,127 @@ func TestFindRootOwnerReference(t *testing.T) { require.Equal(t, "ReplicaSet", rootRef.Kind) }) } + +func TestFindRootControllerRef(t *testing.T) { + // Prepare the scheme + sch := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(sch)) + require.NoError(t, appsv1.AddToScheme(sch)) + + t.Run("no controller returns nil", func(t *testing.T) { + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + }, + } + + c := 
fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.Nil(t, rootRef) + }) + + t.Run("hierarchy returns deployment", func(t *testing.T) { + controller := true + deployment := &appsv1.Deployment{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mydeploy", + Namespace: "default", + UID: "uid-deploy", + }, + } + + rs := &appsv1.ReplicaSet{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "myrs", + Namespace: "default", + UID: "uid-rs", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "mydeploy", + UID: deployment.UID, + Controller: &controller, + }, + }, + }, + } + + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "myrs", + UID: rs.UID, + Controller: &controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, rs, deployment).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, rootRef) + require.Equal(t, "mydeploy", rootRef.Name) + require.Equal(t, "Deployment", rootRef.Kind) + }) + + t.Run("missing controller returns last found ref", func(t *testing.T) { + controller := true + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "missing-rs", + UID: "uid-missing", + Controller: 
&controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, rootRef) + require.Equal(t, "missing-rs", rootRef.Name) + require.Equal(t, "ReplicaSet", rootRef.Kind) + }) +} diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 6c54113d..21f01ec9 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -29,7 +29,6 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/strategicpatch" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -53,7 +52,7 @@ func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.P webhookServer.Register("/mutate-v1-pod", &admission.Webhook{ Handler: &TensorFusionPodMutator{ - decoder: admission.NewDecoder(runtime.NewScheme()), + decoder: admission.NewDecoder(mgr.GetScheme()), Client: mgr.GetClient(), portAllocator: portAllocator, }, @@ -122,19 +121,18 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque podCounterAnnotationKey = podCounterKey } - if tfInfo.PendingSetPodAsOwner { - pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName - } - pool := &tfv1.GPUPool{} if err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.Profile.PoolName}, pool); err != nil { return admission.Errored(http.StatusInternalServerError, fmt.Errorf("gpu pool(%s) does not exist", tfInfo.Profile.PoolName)) } - workload := &tfv1.TensorFusionWorkload{} - if tfInfo.GenWorkload { - if err := m.createOrUpdateWorkload(ctx, pod, &tfInfo, workload, pool); err != nil { - return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err)) + if workload, err := 
m.createOrUpdateWorkload(ctx, pod, &tfInfo, pool); err != nil { + return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err)) + } else { + // Pod mutating webhook can not get Pod UID, + // thus need pod controller to set the controller reference + if controllerRef := metav1.GetControllerOfNoCopy(workload); controllerRef == nil { + pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName } } @@ -201,7 +199,11 @@ func (m *TensorFusionPodMutator) InjectDecoder(d admission.Decoder) error { return nil } -func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod *corev1.Pod, tfInfo *utils.TensorFusionInfo, workload *tfv1.TensorFusionWorkload, pool *tfv1.GPUPool) error { +func (m *TensorFusionPodMutator) createOrUpdateWorkload( + ctx context.Context, + pod *corev1.Pod, + tfInfo *utils.TensorFusionInfo, + pool *tfv1.GPUPool) (*tfv1.TensorFusionWorkload, error) { // Create the desired spec for comparison desiredSpec := tfv1.WorkloadProfileSpec{ Replicas: nil, @@ -214,13 +216,12 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod AutoScalingConfig: tfInfo.Profile.AutoScalingConfig, } + workload := &tfv1.TensorFusionWorkload{} err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.WorkloadName, Namespace: pod.Namespace}, workload) if err != nil { if !errors.IsNotFound(err) { - return fmt.Errorf("failed to get workload: %w", err) + return nil, fmt.Errorf("failed to get workload: %w", err) } - // find root owner references of pod - firstLevelOwnerRef := utils.FindFirstLevelOwnerReference(pod) // Create a new workload workload = &tfv1.TensorFusionWorkload{ @@ -242,25 +243,42 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod workload.Annotations[constants.DisableFeaturesAnnotation] = pod.Labels[constants.DisableFeaturesAnnotation] } - if firstLevelOwnerRef != nil { - workload.OwnerReferences = 
[]metav1.OwnerReference{*firstLevelOwnerRef} + if controllerRef := metav1.GetControllerOf(pod); controllerRef != nil { + workload.OwnerReferences = []metav1.OwnerReference{*controllerRef} } if err := m.Client.Create(ctx, workload); err != nil { - return fmt.Errorf("failed to create workload: %w", err) + return nil, fmt.Errorf("failed to create workload: %w", err) + } + return workload, nil + } + + podControllerRef := metav1.GetControllerOf(pod) + workloadControllerRef := metav1.GetControllerOf(workload) + if !isSameControllerRef(podControllerRef, workloadControllerRef) || + !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { + patch := client.MergeFrom(workload.DeepCopy()) + if podControllerRef != nil { + workload.OwnerReferences = []metav1.OwnerReference{*podControllerRef} + } else { + workload.OwnerReferences = []metav1.OwnerReference{} } - return nil - } - - // Compare the entire spec at once - if !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { workload.Spec = desiredSpec - // TODO retry on conflict - if err := m.Client.Update(ctx, workload); err != nil { - return fmt.Errorf("failed to update workload: %w", err) + if err := m.Client.Patch(ctx, workload, patch); err != nil { + return nil, fmt.Errorf("failed to patch workload: %w", err) } } - return nil + return workload, nil +} + +func isSameControllerRef(a, b *metav1.OwnerReference) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.UID == b.UID } func (m *TensorFusionPodMutator) patchTFClient( diff --git a/internal/webhook/v1/pod_webhook_test.go b/internal/webhook/v1/pod_webhook_test.go index 374f2620..ac93f21c 100644 --- a/internal/webhook/v1/pod_webhook_test.go +++ b/internal/webhook/v1/pod_webhook_test.go @@ -595,6 +595,57 @@ var _ = Describe("TensorFusionPodMutator", func() { Expect(tfInfo.Profile.Qos).To(Equal(tfv1.QoSHigh)) Expect(*tfInfo.EnabledReplicas).To(Equal(int32(3))) }) + + It("should treat generateName as workload name 
if the pod has no controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + GenerateName: "test-name", + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + }, + }, + }, + } + tfInfo, _ := ParseTensorFusionInfo(ctx, k8sClient, pod) + Expect(tfInfo.WorkloadName).To(HavePrefix("test-name")) + }) + + It("should treat controller name as workload name if the pod has controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + GenerateName: "test-name", + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "test-rs", + UID: "rs-uid", + Controller: ptr.To(true), + }, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + }, + }, + }, + } + tfInfo, _ := ParseTensorFusionInfo(ctx, k8sClient, pod) + Expect(tfInfo.WorkloadName).To(Equal("test-rs")) + }) }) Context("patchTFClient", func() { @@ -622,4 +673,137 @@ var _ = Describe("TensorFusionPodMutator", func() { Expect(len(patch)).To(BeNumerically(">=", 2)) }) }) + + Context("when handling workload", func() { + It("should update workload's controllerRef same with Pod's controllerRef", func() { + expectedRef := metav1.OwnerReference{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "test-rs", + UID: "rs-uid", + Controller: ptr.To(true), + } + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-name", + Labels: map[string]string{ + constants.TensorFusionEnabledLabelKey: "true", + }, + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + OwnerReferences: []metav1.OwnerReference{expectedRef}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "main", + Image: "test-image", + }}, + }, + } + podBytes, 
err := json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req := admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp := mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + Expect(pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation]).To(BeEmpty()) + + Eventually(func(g Gomega) { + workload := &tfv1.TensorFusionWorkload{} + g.Expect(k8sClient.Get(ctx, + client.ObjectKey{ + Name: expectedRef.Name, + Namespace: "default", + }, workload)).To(Succeed()) + gotRef := metav1.GetControllerOfNoCopy(workload) + g.Expect(*gotRef).To(Equal(expectedRef)) + }).Should(Succeed()) + + newExpectedRef := metav1.OwnerReference{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "new-test-rs", + UID: "new-rs-uid", + Controller: ptr.To(true), + } + pod.OwnerReferences = []metav1.OwnerReference{newExpectedRef} + podBytes, err = json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req = admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp = mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + Expect(pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation]).To(BeEmpty()) + + Eventually(func(g Gomega) { + workload := &tfv1.TensorFusionWorkload{} + g.Expect(k8sClient.Get(ctx, + client.ObjectKey{ + Name: newExpectedRef.Name, + Namespace: "default", + }, workload)).To(Succeed()) + gotRef := metav1.GetControllerOfNoCopy(workload) + g.Expect(*gotRef).To(Equal(newExpectedRef)) + }).Should(Succeed()) + }) + + It("should add SetPendingOwnedWorkload annotation to pod when workload has no controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-name", + Labels: map[string]string{ + constants.TensorFusionEnabledLabelKey: 
"true", + }, + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "main", + Image: "test-image", + }}, + }, + } + podBytes, err := json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req := admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp := mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + annotation, found := lo.Find(resp.Patches, func(patch jsonpatch.JsonPatchOperation) bool { + return patch.Path == "/metadata/annotations/tensor-fusion.ai~1pending-owned-workload" + }) + Expect(found).To(BeTrue()) + Expect(annotation.Value).To(HavePrefix("test-name")) + }) + }) }) diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index dfd8fd19..b0fceb5f 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -10,8 +10,11 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" "github.com/NexusGPU/tensor-fusion/internal/utils" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -47,25 +50,22 @@ func ParseTensorFusionInfo( info.EnabledReplicas = &val32 } - workloadName, ok := pod.Annotations[constants.WorkloadKey] - if !ok { - // auto generate a workload with owner name - info.GenWorkload = true - owner := utils.FindFirstLevelOwnerReference(pod) - if owner == nil { + // Generate the workload name: + // If the Pod has no controller, use the Pod's name; + // if it is controlled by a Deployment, return the Deployment's name; + // otherwise, return the name of the first-level controller. 
+	if controllerRef, err := getPodControllerRef(ctx, k8sClient, pod); err == nil {
+		if controllerRef != nil {
+			info.WorkloadName = controllerRef.Name
+		} else {
 			if pod.Name == "" {
 				info.WorkloadName = pod.GenerateName + "-" + utils.NewShortID(8)
 			} else {
 				info.WorkloadName = pod.Name
 			}
-			info.PendingSetPodAsOwner = true
-		} else {
-			info.WorkloadName = owner.Name
 		}
 	} else {
-		// when workload is manually created, user can specify workload's replicas
-		// it remotely connects to lease connection worker when SelectWorker
-		info.WorkloadName = workloadName
+		return info, err
 	}
 
 	workloadProfileName, ok := pod.Annotations[constants.WorkloadProfileAnnotation]
@@ -260,3 +260,34 @@ func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile)
 	workloadProfile.Spec.Resources.Limits.Vram = resource.Vram
 	return nil
 }
+
+func getPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) {
+	podControllerRef := metav1.GetControllerOf(pod)
+	if podControllerRef == nil {
+		return nil, nil
+	}
+
+	switch podControllerRef.Kind {
+	case "ReplicaSet":
+		{
+			// Special handling for Deployment resources
+			rs := &appsv1.ReplicaSet{}
+			if err := c.Get(ctx, client.ObjectKey{
+				Namespace: pod.Namespace,
+				Name:      podControllerRef.Name,
+			}, rs); err != nil {
+				if errors.IsNotFound(err) {
+					return podControllerRef, nil
+				}
+				return nil, fmt.Errorf("failed to get ReplicaSet: %w", err)
+			}
+			rsControllerRef := metav1.GetControllerOf(rs)
+			if rsControllerRef != nil && rsControllerRef.Kind == "Deployment" {
+				// If controlled by a Deployment, return the controllerRef of rs
+				return rsControllerRef, nil
+			}
+		}
+	}
+
+	return podControllerRef, nil
+}