Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func init() {
}
karpenterScheme.Register(&karpv1.NodeClaim{}, &karpv1.NodeClaimList{})
karpenterScheme.Register(&karpv1.NodePool{}, &karpv1.NodePoolList{})
karpenterScheme.AddToScheme(scheme)
utilruntime.Must(karpenterScheme.AddToScheme(scheme))
}

//nolint:gocyclo
Expand Down
5 changes: 0 additions & 5 deletions internal/utils/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ type TensorFusionInfo struct {
EnabledReplicas *int32
WorkloadName string
ContainerNames []string
GenWorkload bool

// Pod mutating webhook can not get Pod UID sometimes,
// thus need pod controller to set the owner reference
PendingSetPodAsOwner bool
}

func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo TensorFusionInfo) {
Expand Down
40 changes: 40 additions & 0 deletions internal/utils/owner_ref_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,43 @@ func FindFirstLevelOwnerReference(obj metav1.Object) *metav1.OwnerReference {
}
return &ownerRef
}

// FindRootControllerRef walks the chain of controller owner references that
// starts at obj (e.g. a Pod) and returns a reference to the top-most controller.
//
// Every controller in the chain is fetched through c as an unstructured object,
// always using obj's namespace for the lookup. The result is:
//   - nil, nil when obj itself has no controller reference;
//   - a controller reference built from the last fetched object, once an object
//     without a controller reference is reached;
//   - the last controller reference seen, when the object it names is not found.
//
// An error is returned when a lookup fails for any reason other than "not
// found", or when the controller references loop back onto a controller that
// was already visited.
func FindRootControllerRef(ctx context.Context, c client.Client, obj metav1.Object) (*metav1.OwnerReference, error) {
	if metav1.GetControllerOfNoCopy(obj) == nil {
		return nil, nil
	}

	// refKey identifies a controller inside the single namespace being searched.
	type refKey struct{ apiVersion, kind, name string }
	// visited guards against cyclic controller references, which would
	// otherwise keep this loop running forever.
	visited := map[refKey]struct{}{}

	namespace := obj.GetNamespace()
	current := obj
	for {
		controllerRef := metav1.GetControllerOf(current)
		if controllerRef == nil {
			// current has no controller, so it is the root of the chain.
			rObj, ok := current.(runtime.Object)
			if !ok {
				return nil, fmt.Errorf("not a runtime.Object")
			}
			gvk := rObj.GetObjectKind().GroupVersionKind()
			return metav1.NewControllerRef(current, gvk), nil
		}

		key := refKey{
			apiVersion: controllerRef.APIVersion,
			kind:       controllerRef.Kind,
			name:       controllerRef.Name,
		}
		if _, seen := visited[key]; seen {
			return nil, fmt.Errorf("controller reference cycle detected at %s/%s", controllerRef.Kind, controllerRef.Name)
		}
		visited[key] = struct{}{}

		owner := &unstructured.Unstructured{}
		owner.SetAPIVersion(controllerRef.APIVersion)
		owner.SetKind(controllerRef.Kind)
		err := c.Get(ctx, client.ObjectKey{Name: controllerRef.Name, Namespace: namespace}, owner)
		if err != nil {
			// if not found, return controllerRef as root
			if errors.IsNotFound(err) {
				return controllerRef, nil
			}
			return nil, fmt.Errorf("get controller object: %w", err)
		}

		// *unstructured.Unstructured already implements metav1.Object, so it
		// can become the next link in the chain without a type assertion.
		current = owner
	}
}
124 changes: 124 additions & 0 deletions internal/utils/owner_ref_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,127 @@ func TestFindRootOwnerReference(t *testing.T) {
require.Equal(t, "ReplicaSet", rootRef.Kind)
})
}

// TestFindRootControllerRef checks utils.FindRootControllerRef against a fake
// client for three situations: a pod without a controller, a
// Pod -> ReplicaSet -> Deployment chain, and a pod whose controller is absent
// from the cluster.
func TestFindRootControllerRef(t *testing.T) {
	// Scheme holding the core and apps types used by the fixtures.
	testScheme := runtime.NewScheme()
	require.NoError(t, corev1.AddToScheme(testScheme))
	require.NoError(t, appsv1.AddToScheme(testScheme))

	const ns = "default"
	isController := true

	// newPod builds the pod fixture shared by every case; owners become its
	// owner references (none when called without arguments).
	newPod := func(owners ...metav1.OwnerReference) *corev1.Pod {
		return &corev1.Pod{
			TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Pod"},
			ObjectMeta: metav1.ObjectMeta{
				Name:            "mypod",
				Namespace:       ns,
				UID:             "uid-pod",
				OwnerReferences: owners,
			},
		}
	}

	t.Run("no controller returns nil", func(t *testing.T) {
		pod := newPod()
		fakeClient := fake.NewClientBuilder().WithScheme(testScheme).WithObjects(pod).Build()

		got, err := utils.FindRootControllerRef(context.TODO(), fakeClient, pod)
		require.NoError(t, err)
		require.Nil(t, got)
	})

	t.Run("hierarchy returns deployment", func(t *testing.T) {
		deployment := &appsv1.Deployment{
			TypeMeta:   metav1.TypeMeta{APIVersion: "apps/v1", Kind: "Deployment"},
			ObjectMeta: metav1.ObjectMeta{Name: "mydeploy", Namespace: ns, UID: "uid-deploy"},
		}

		// The ReplicaSet is controlled by the Deployment.
		rs := &appsv1.ReplicaSet{
			TypeMeta: metav1.TypeMeta{APIVersion: "apps/v1", Kind: "ReplicaSet"},
			ObjectMeta: metav1.ObjectMeta{
				Name:      "myrs",
				Namespace: ns,
				UID:       "uid-rs",
				OwnerReferences: []metav1.OwnerReference{{
					APIVersion: "apps/v1",
					Kind:       "Deployment",
					Name:       "mydeploy",
					UID:        deployment.UID,
					Controller: &isController,
				}},
			},
		}

		// The pod is controlled by the ReplicaSet.
		pod := newPod(metav1.OwnerReference{
			APIVersion: "apps/v1",
			Kind:       "ReplicaSet",
			Name:       "myrs",
			UID:        rs.UID,
			Controller: &isController,
		})

		fakeClient := fake.NewClientBuilder().WithScheme(testScheme).WithObjects(pod, rs, deployment).Build()

		got, err := utils.FindRootControllerRef(context.TODO(), fakeClient, pod)
		require.NoError(t, err)
		require.NotNil(t, got)
		require.Equal(t, "mydeploy", got.Name)
		require.Equal(t, "Deployment", got.Kind)
	})

	t.Run("missing controller returns last found ref", func(t *testing.T) {
		// The referenced ReplicaSet is never registered with the fake client.
		pod := newPod(metav1.OwnerReference{
			APIVersion: "apps/v1",
			Kind:       "ReplicaSet",
			Name:       "missing-rs",
			UID:        "uid-missing",
			Controller: &isController,
		})

		fakeClient := fake.NewClientBuilder().WithScheme(testScheme).WithObjects(pod).Build()

		got, err := utils.FindRootControllerRef(context.TODO(), fakeClient, pod)
		require.NoError(t, err)
		require.NotNil(t, got)
		require.Equal(t, "missing-rs", got.Name)
		require.Equal(t, "ReplicaSet", got.Kind)
	})
}
70 changes: 44 additions & 26 deletions internal/webhook/v1/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/strategicpatch"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand All @@ -53,7 +52,7 @@ func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.P
webhookServer.Register("/mutate-v1-pod",
&admission.Webhook{
Handler: &TensorFusionPodMutator{
decoder: admission.NewDecoder(runtime.NewScheme()),
decoder: admission.NewDecoder(mgr.GetScheme()),
Client: mgr.GetClient(),
portAllocator: portAllocator,
},
Expand Down Expand Up @@ -122,19 +121,18 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
podCounterAnnotationKey = podCounterKey
}

if tfInfo.PendingSetPodAsOwner {
pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName
}

pool := &tfv1.GPUPool{}
if err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.Profile.PoolName}, pool); err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("gpu pool(%s) does not exist", tfInfo.Profile.PoolName))
}

workload := &tfv1.TensorFusionWorkload{}
if tfInfo.GenWorkload {
if err := m.createOrUpdateWorkload(ctx, pod, &tfInfo, workload, pool); err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err))
if workload, err := m.createOrUpdateWorkload(ctx, pod, &tfInfo, pool); err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err))
} else {
// The pod mutating webhook cannot get the Pod UID,
// so the pod controller has to set the controller reference later.
if controllerRef := metav1.GetControllerOfNoCopy(workload); controllerRef == nil {
pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName
}
}

Expand Down Expand Up @@ -201,7 +199,11 @@ func (m *TensorFusionPodMutator) InjectDecoder(d admission.Decoder) error {
return nil
}

func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod *corev1.Pod, tfInfo *utils.TensorFusionInfo, workload *tfv1.TensorFusionWorkload, pool *tfv1.GPUPool) error {
func (m *TensorFusionPodMutator) createOrUpdateWorkload(
ctx context.Context,
pod *corev1.Pod,
tfInfo *utils.TensorFusionInfo,
pool *tfv1.GPUPool) (*tfv1.TensorFusionWorkload, error) {
// Create the desired spec for comparison
desiredSpec := tfv1.WorkloadProfileSpec{
Replicas: nil,
Expand All @@ -214,13 +216,12 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
AutoScalingConfig: tfInfo.Profile.AutoScalingConfig,
}

workload := &tfv1.TensorFusionWorkload{}
err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.WorkloadName, Namespace: pod.Namespace}, workload)
if err != nil {
if !errors.IsNotFound(err) {
return fmt.Errorf("failed to get workload: %w", err)
return nil, fmt.Errorf("failed to get workload: %w", err)
}
// find root owner references of pod
firstLevelOwnerRef := utils.FindFirstLevelOwnerReference(pod)

// Create a new workload
workload = &tfv1.TensorFusionWorkload{
Expand All @@ -242,25 +243,42 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
workload.Annotations[constants.DisableFeaturesAnnotation] = pod.Labels[constants.DisableFeaturesAnnotation]
}

if firstLevelOwnerRef != nil {
workload.OwnerReferences = []metav1.OwnerReference{*firstLevelOwnerRef}
if controllerRef := metav1.GetControllerOf(pod); controllerRef != nil {
workload.OwnerReferences = []metav1.OwnerReference{*controllerRef}
}

if err := m.Client.Create(ctx, workload); err != nil {
return fmt.Errorf("failed to create workload: %w", err)
return nil, fmt.Errorf("failed to create workload: %w", err)
}
return workload, nil
}

podControllerRef := metav1.GetControllerOf(pod)
workloadControllerRef := metav1.GetControllerOf(workload)
if !isSameControllerRef(podControllerRef, workloadControllerRef) ||
!equality.Semantic.DeepEqual(workload.Spec, desiredSpec) {
patch := client.MergeFrom(workload.DeepCopy())
if podControllerRef != nil {
workload.OwnerReferences = []metav1.OwnerReference{*podControllerRef}
} else {
workload.OwnerReferences = []metav1.OwnerReference{}
}
return nil
}

// Compare the entire spec at once
if !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) {
workload.Spec = desiredSpec
// TODO retry on conflict
if err := m.Client.Update(ctx, workload); err != nil {
return fmt.Errorf("failed to update workload: %w", err)
if err := m.Client.Patch(ctx, workload, patch); err != nil {
return nil, fmt.Errorf("failed to patch workload: %w", err)
}
}
return nil
return workload, nil
}

// isSameControllerRef reports whether a and b denote the same controller.
// Two nil references are considered the same, a nil and a non-nil reference
// are not, and two non-nil references are compared by UID only.
func isSameControllerRef(a, b *metav1.OwnerReference) bool {
	switch {
	case a == nil || b == nil:
		// With at least one side nil, the pointers are equal only when both are nil.
		return a == b
	default:
		return a.UID == b.UID
	}
}

func (m *TensorFusionPodMutator) patchTFClient(
Expand Down
Loading