From deb9169cde1bcd3a455efc2c295c70ad50f7984e Mon Sep 17 00:00:00 2001 From: Michael Burman Date: Fri, 10 Apr 2026 11:33:56 +0300 Subject: [PATCH 1/5] Remove first set of Kubernetes client mocks and replace them with fakeclient --- pkg/reconciliation/decommission_node_test.go | 7 - pkg/reconciliation/handler_reconcile_test.go | 29 +- pkg/reconciliation/handler_test.go | 240 ++++------ .../reconcile_datacenter_test.go | 208 ++++----- pkg/reconciliation/reconcile_fql_test.go | 21 +- pkg/reconciliation/reconcile_racks_test.go | 419 +++++++++--------- pkg/reconciliation/reconcile_services_test.go | 153 ++----- pkg/reconciliation/testing.go | 180 ++------ 8 files changed, 487 insertions(+), 770 deletions(-) diff --git a/pkg/reconciliation/decommission_node_test.go b/pkg/reconciliation/decommission_node_test.go index 1121e8804..7bae490f6 100644 --- a/pkg/reconciliation/decommission_node_test.go +++ b/pkg/reconciliation/decommission_node_test.go @@ -26,9 +26,6 @@ func TestRetryDecommissionNode(t *testing.T) { state := "UP" podIP := "192.168.101.11" - mockClient := mocks.NewClient(t) - rc.Client = mockClient - rc.Datacenter.SetCondition(api.DatacenterCondition{ Status: corev1.ConditionTrue, Type: api.DatacenterScalingDown, @@ -109,15 +106,11 @@ func TestRemoveResourcesWhenDone(t *testing.T) { podIP := "192.168.101.11" state := "LEFT" - mockClient := mocks.NewClient(t) - rc.Client = mockClient rc.Datacenter.SetCondition(api.DatacenterCondition{ Status: corev1.ConditionTrue, Type: api.DatacenterScalingDown, }) - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) - labels := make(map[string]string) labels[api.CassNodeState] = stateDecommissioning diff --git a/pkg/reconciliation/handler_reconcile_test.go b/pkg/reconciliation/handler_reconcile_test.go index aa5d45452..30f77002b 100644 --- a/pkg/reconciliation/handler_reconcile_test.go +++ b/pkg/reconciliation/handler_reconcile_test.go @@ -7,8 +7,6 @@ import ( "time" api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" - "github.com/k8ssandra/cass-operator/pkg/mocks" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" controllers "github.com/k8ssandra/cass-operator/internal/controllers/cassandra" @@ -22,6 +20,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" "sigs.k8s.io/controller-runtime/pkg/reconcile" ) @@ -222,25 +221,17 @@ func TestReconcile_Error(t *testing.T) { s := scheme.Scheme s.AddKnownTypes(api.GroupVersion, dc) - mockClient := &mocks.Client{} - mockClient.On("Get", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(key client.ObjectKey) bool { - return key != client.ObjectKey{} - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(fmt.Errorf("some cryptic error")). - Once() + fakeClient := fake.NewClientBuilder(). + WithScheme(s). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, c client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("some cryptic error") + }, + }). 
+ Build() r := &controllers.CassandraDatacenterReconciler{ - Client: mockClient, + Client: fakeClient, Scheme: s, } diff --git a/pkg/reconciliation/handler_test.go b/pkg/reconciliation/handler_test.go index eb6ed4545..35123a86b 100644 --- a/pkg/reconciliation/handler_test.go +++ b/pkg/reconciliation/handler_test.go @@ -4,27 +4,26 @@ package reconciliation import ( + "context" "fmt" - "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/reconcile" api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" "github.com/k8ssandra/cass-operator/pkg/dynamicwatch" - "github.com/k8ssandra/cass-operator/pkg/mocks" ) func TestCalculateReconciliationActions(t *testing.T) { @@ -57,7 +56,7 @@ func TestCalculateReconciliationActions(t *testing.T) { pod, } - fakeClient := fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter, service).WithRuntimeObjects(trackObjects...).Build() + fakeClient := fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter, service).WithRuntimeObjects(trackObjects...).Build() rc.Client = fakeClient result, err := rc.CalculateReconciliationActions() @@ -73,41 +72,62 @@ func TestCalculateReconciliationActions_GetServiceError(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientGet(mockClient, fmt.Errorf("")) - k8sMockClientUpdate(mockClient, nil).Times(1) + getErr := fmt.Errorf("") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, c client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if _, ok := obj.(*corev1.Service); ok { + return getErr + } + return c.Get(ctx, key, obj, opts...) + }, + }). + Build() _, err := rc.CalculateReconciliationActions() assert.Errorf(t, err, "Should have returned an error while calculating reconciliation actions") - - mockClient.AssertExpectations(t) } func TestCalculateReconciliationActions_FailedUpdate(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientUpdate(mockClient, fmt.Errorf("failed to update CassandraDatacenter with removed finalizers")) + updateErr := fmt.Errorf("failed to update CassandraDatacenter with removed finalizers") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
+ WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.UpdateOption) error { + if _, ok := obj.(*api.CassandraDatacenter); ok { + return updateErr + } + return c.Update(ctx, obj, opts...) + }, + }). + Build() _, err := rc.CalculateReconciliationActions() assert.Errorf(t, err, "Should have returned an error while calculating reconciliation actions") +} - mockClient.AssertExpectations(t) +func emptySecretWatcher(rc *ReconciliationContext) { + rc.SecretWatches = dynamicwatch.NewDynamicSecretWatches(rc.Client) } -func emptySecretWatcher(t *testing.T, rc *ReconciliationContext) { - mockClient := mocks.NewClient(t) - rc.SecretWatches = dynamicwatch.NewDynamicSecretWatches(mockClient) - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*unstructured.UnstructuredList) - arg.Items = []unstructured.Unstructured{} - }) +func pvcProto(rc *ReconciliationContext) *corev1.PersistentVolumeClaim { + return &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: "server-data", + Namespace: rc.Datacenter.Namespace, + Labels: rc.Datacenter.GetDatacenterLabels(), + }, + } } // TestProcessDeletion_FailedDelete fails one step of the deletion process and should not cause @@ -117,57 +137,38 @@ func TestProcessDeletion_FailedDelete(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient rc.Datacenter.Spec.Size = 0 - - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - _, ok := args.Get(1).(*corev1.PodList) - if ok { - opts := listOptionsFromArg(args.Get(2)) - if opts == nil { - t.Fail() - return - } - if strings.HasPrefix(opts.FieldSelector.String(), "spec.volumes.persistentVolumeClaim.claimName") { - arg := args.Get(1).(*corev1.PodList) - arg.Items = []corev1.Pod{} - } else { - t.Fail() - } - return - } - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }).Twice() - - k8sMockClientGet(mockClient, nil). - Run(func(args mock.Arguments) { - arg := args.Get(2).(*appsv1.StatefulSet) - arg.Spec.Replicas = ptr.To[int32](0) - }).Once() - - k8sMockClientDelete(mockClient, fmt.Errorf("")) - - emptySecretWatcher(t, rc) - - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) // Update dc status - - rc.Datacenter.SetFinalizers([]string{"finalizer.cassandra.datastax.com"}) + rc.Datacenter.SetFinalizers([]string{api.Finalizer}) now := metav1.Now() rc.Datacenter.SetDeletionTimestamp(&now) + sts, err := newStatefulSetForCassandraDatacenter(nil, "default", rc.Datacenter, 0, imageRegistry) + assert.NoError(err) + pvc := pvcProto(rc) + deleteErr := fmt.Errorf("failed to delete pvc") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, sts, pvc). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + Delete: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.DeleteOption) error { + if _, ok := obj.(*corev1.PersistentVolumeClaim); ok { + return deleteErr + } + return c.Delete(ctx, obj, opts...) + }, + }). 
+ Build() + emptySecretWatcher(rc) result, err := rc.CalculateReconciliationActions() assert.Errorf(err, "Should have returned an error while calculating reconciliation actions") assert.Equal(reconcile.Result{}, result, "Should not requeue request as error does cause requeue") - assert.True(len(rc.Datacenter.GetFinalizers()) > 0) - mockClient.AssertExpectations(t) + storedDC := &api.CassandraDatacenter{} + getErr := rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(rc.Datacenter), storedDC) + assert.NoError(getErr) + assert.True(controllerutil.ContainsFinalizer(storedDC, api.Finalizer)) } // TestProcessDeletion verifies the correct amount of calls to k8sClient on the deletion process @@ -176,55 +177,27 @@ func TestProcessDeletion(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - _, ok := args.Get(1).(*corev1.PodList) - if ok { - opts := listOptionsFromArg(args.Get(2)) - if opts == nil { - t.Fail() - return - } - if strings.HasPrefix(opts.FieldSelector.String(), "spec.volumes.persistentVolumeClaim.claimName") { - arg := args.Get(1).(*corev1.PodList) - arg.Items = []corev1.Pod{} - } else { - t.Fail() - } - return - } - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }).Twice() // ListPods - - k8sMockClientDelete(mockClient, nil) // Delete PVC - k8sMockClientUpdate(mockClient, nil) // Remove dc finalizer - k8sMockClientGet(mockClient, nil). - Run(func(args mock.Arguments) { - arg := args.Get(2).(*appsv1.StatefulSet) - arg.Spec.Replicas = ptr.To[int32](0) - }).Once() - - emptySecretWatcher(t, rc) - - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) // Update dc status - - rc.Datacenter.SetFinalizers([]string{"finalizer.cassandra.datastax.com"}) + rc.Datacenter.SetFinalizers([]string{api.Finalizer}) now := metav1.Now() rc.Datacenter.SetDeletionTimestamp(&now) + sts, err := newStatefulSetForCassandraDatacenter(nil, "default", rc.Datacenter, 0, imageRegistry) + assert.NoError(err) + pvc := pvcProto(rc) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, sts, pvc). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
+ Build() + emptySecretWatcher(rc) result, err := rc.CalculateReconciliationActions() assert.NoError(err) assert.Equal(reconcile.Result{}, result, "Should not requeue request") - mockClient.AssertExpectations(t) + storedDC := &api.CassandraDatacenter{} + getErr := rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(rc.Datacenter), storedDC) + assert.True(apierrors.IsNotFound(getErr)) } // TestProcessDeletion_NoFinalizer verifies that the removal of finalizer means cass-operator will do nothing @@ -234,17 +207,12 @@ func TestProcessDeletion_NoFinalizer(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - now := metav1.Now() rc.Datacenter.SetDeletionTimestamp(&now) result, err := rc.CalculateReconciliationActions() assert.NoError(err) assert.Equal(reconcile.Result{}, result, "Should not requeue request") - - mockClient.AssertExpectations(t) } func TestAddFinalizer(t *testing.T) { @@ -252,18 +220,11 @@ func TestAddFinalizer(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - k8sMockClientUpdate(mockClient, nil).Times(1) // Add finalizer - err := rc.addFinalizer() assert.NoError(err) assert.True(controllerutil.ContainsFinalizer(rc.Datacenter, api.Finalizer)) - mockClient.AssertExpectations(t) // This should not add the finalizer again - mockClient = mocks.NewClient(t) - rc.Client = mockClient rc.Datacenter.Annotations = make(map[string]string) rc.Datacenter.Annotations[api.NoFinalizerAnnotation] = "true" controllerutil.RemoveFinalizer(rc.Datacenter, api.Finalizer) @@ -277,25 +238,20 @@ func TestConflictingDcNameOverride(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, nil). 
- Run(func(args mock.Arguments) { - arg := args.Get(1).(*api.CassandraDatacenterList) - arg.Items = []api.CassandraDatacenter{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "dc1", - }, - Spec: api.CassandraDatacenterSpec{ - ClusterName: "cluster1", - DatacenterName: "CassandraDatacenter_example", - }, - Status: api.CassandraDatacenterStatus{ - DatacenterName: ptr.To[string]("CassandraDatacenter_example"), - }, - }} - }) + err := rc.Client.Create(rc.Ctx, &api.CassandraDatacenter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dc1", + Namespace: rc.Datacenter.Namespace, + }, + Spec: api.CassandraDatacenterSpec{ + ClusterName: "cluster1", + DatacenterName: "CassandraDatacenter_example", + }, + Status: api.CassandraDatacenterStatus{ + DatacenterName: ptr.To[string]("CassandraDatacenter_example"), + }, + }) + assert.NoError(err) errs := rc.validateDatacenterNameConflicts() assert.NotEmpty(errs, "validateDatacenterNameConflicts should return an error as the datacenter name is already in use") diff --git a/pkg/reconciliation/reconcile_datacenter_test.go b/pkg/reconciliation/reconcile_datacenter_test.go index 815f8d1f2..ce4311882 100644 --- a/pkg/reconciliation/reconcile_datacenter_test.go +++ b/pkg/reconciliation/reconcile_datacenter_test.go @@ -4,12 +4,11 @@ package reconciliation import ( + "context" "fmt" - "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" @@ -18,44 +17,24 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" - "github.com/k8ssandra/cass-operator/pkg/mocks" ) func TestDeletePVCs(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - _, ok := args.Get(1).(*corev1.PodList) - if ok { - opts := listOptionsFromArg(args.Get(2)) - if opts == nil { - t.Fail() - return - } - if strings.HasPrefix(opts.FieldSelector.String(), "spec.volumes.persistentVolumeClaim.claimName") { - arg := args.Get(1).(*corev1.PodList) - arg.Items = []corev1.Pod{} - } else { - t.Fail() - } - return - } - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }).Twice() - - k8sMockClientDelete(mockClient, nil) + pvc := pvcProto(rc) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, pvc). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() err := rc.deletePVCs() if err != nil { @@ -67,18 +46,21 @@ func TestDeletePVCs_FailedToList(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, fmt.Errorf("failed to list PVCs for CassandraDatacenter")). 
- Run(func(args mock.Arguments) { - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }) + listErr := fmt.Errorf("failed to list PVCs for CassandraDatacenter") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, c client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if _, ok := list.(*corev1.PersistentVolumeClaimList); ok { + return listErr + } + return c.List(ctx, list, opts...) + }, + }). + Build() err := rc.deletePVCs() if err == nil { @@ -91,18 +73,21 @@ func TestDeletePVCs_PVCsNotFound(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, errors.NewNotFound(schema.GroupResource{}, "name")). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }) + notFoundErr := errors.NewNotFound(schema.GroupResource{}, "name") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, c client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if _, ok := list.(*corev1.PersistentVolumeClaimList); ok { + return notFoundErr + } + return c.List(ctx, list, opts...) + }, + }). + Build() assert.NoError(rc.deletePVCs()) } @@ -111,35 +96,22 @@ func TestDeletePVCs_FailedToDelete(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - _, ok := args.Get(1).(*corev1.PodList) - if ok { - opts := listOptionsFromArg(args.Get(2)) - if opts == nil { - t.Fail() - return - } - if strings.HasPrefix(opts.FieldSelector.String(), "spec.volumes.persistentVolumeClaim.claimName") { - arg := args.Get(1).(*corev1.PodList) - arg.Items = []corev1.Pod{} - } else { - t.Fail() + pvc := pvcProto(rc) + deleteErr := fmt.Errorf("failed to delete") + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, pvc). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + Delete: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.DeleteOption) error { + if _, ok := obj.(*corev1.PersistentVolumeClaim); ok { + return deleteErr } - return - } - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }).Twice() - - k8sMockClientDelete(mockClient, fmt.Errorf("failed to delete")) + return c.Delete(ctx, obj, opts...) + }, + }). 
+ Build() err := rc.deletePVCs() if err == nil { @@ -154,52 +126,24 @@ func TestDeletePVCs_FailedToDeleteBeingUsed(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientList(mockClient, nil). - Run(func(args mock.Arguments) { - _, ok := args.Get(1).(*corev1.PodList) - if ok { - opts := listOptionsFromArg(args.Get(2)) - if opts == nil { - t.Fail() - return - } - if strings.HasPrefix(opts.FieldSelector.String(), "spec.volumes.persistentVolumeClaim.claimName") { - arg := args.Get(1).(*corev1.PodList) - arg.Items = []corev1.Pod{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "pod-1", - }, - }, - } - } else { - t.Fail() - } - return - } - arg := args.Get(1).(*corev1.PersistentVolumeClaimList) - arg.Items = []corev1.PersistentVolumeClaim{{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pvc-1", - }, - }} - }).Twice() + pvc := pvcProto(rc) + pod := podWithPVC(rc, "pod-1", pvc.Name) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, pvc, pod). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() err := rc.deletePVCs() assert.Error(err) - assert.EqualError(err, "PersistentVolumeClaim pvc-1 is still being used by a pod") + assert.EqualError(err, "PersistentVolumeClaim server-data is still being used by a pod") } func TestDeletePVCsSkip(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - rc.Datacenter.Annotations = map[string]string{ api.DeletePVCAnnotation: "false", } @@ -207,9 +151,6 @@ func TestDeletePVCsSkip(t *testing.T) { if err := rc.deletePVCs(); err != nil { t.Fatalf("deletePVCs should not have failed") } - - mockClient.AssertNotCalled(t, "List", mock.Anything, mock.Anything, mock.Anything) - mockClient.AssertNotCalled(t, "Delete", mock.Anything, mock.Anything) } func TestStorageExpansionNils(t *testing.T) { @@ -231,3 +172,22 @@ func TestStorageExpansionNils(t *testing.T) { require.NoError(err) require.True(supports) } + +func podWithPVC(rc *ReconciliationContext, podName string, pvcName string) *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: rc.Datacenter.Namespace, + }, + Spec: corev1.PodSpec{ + Volumes: []corev1.Volume{{ + Name: "server-data", + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvcName, + }, + }, + }}, + }, + } +} diff --git a/pkg/reconciliation/reconcile_fql_test.go b/pkg/reconciliation/reconcile_fql_test.go index 2ee95e4b6..cb4fb45e0 100644 --- a/pkg/reconciliation/reconcile_fql_test.go +++ b/pkg/reconciliation/reconcile_fql_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) var fqlEnabledConfig string = `{"cassandra-yaml": { @@ -38,13 +39,14 @@ var ( ) func setupPodList(t *testing.T, rc *ReconciliationContext) { + t.Helper() podIP := "192.168.101.11" - mockClient := mocks.NewClient(t) - pods := []corev1.Pod{{ ObjectMeta: metav1.ObjectMeta{ - Name: "test-sts-0", + Name: "test-sts-0", + Namespace: rc.Datacenter.Namespace, + Labels: rc.Datacenter.GetClusterLabels(), }, Status: corev1.PodStatus{ PodIP: podIP, @@ -55,13 +57,12 @@ func setupPodList(t *testing.T, rc *ReconciliationContext) { &pods[0], } - k8sMockClientList(mockClient, 
nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*corev1.PodList) - arg.Items = pods - }) - - rc.Client = mockClient + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter, &pods[0]). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() } func mockFeaturesEnabled(mockHttpClient *mocks.HttpClient) { diff --git a/pkg/reconciliation/reconcile_racks_test.go b/pkg/reconciliation/reconcile_racks_test.go index d2656be8d..2ce18ab80 100644 --- a/pkg/reconciliation/reconcile_racks_test.go +++ b/pkg/reconciliation/reconcile_racks_test.go @@ -43,6 +43,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -240,7 +241,7 @@ func TestReconcileRacks_ReconcilePods(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = desiredStatefulSet.Labels[api.RackLabel] @@ -511,36 +512,6 @@ func TestReconcilePods(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientGet(mockClient, nil) - - // this mock will only pass if the pod is updated with the correct labels - mockClient.On("Update", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj *corev1.Pod) bool { - dc := api.CassandraDatacenter{ - ObjectMeta: metav1.ObjectMeta{ - Name: "cassandradatacenter-example", - Namespace: "default", - }, - Spec: api.CassandraDatacenterSpec{ - ClusterName: "cassandradatacenter-example-cluster", - }, - } - expected := dc.GetRackLabels("default") - expected[oplabels.ManagedByLabel] = oplabels.ManagedByLabelValue - - return reflect.DeepEqual(obj.GetLabels(), expected) - })). - Return(nil). 
- Once() - statefulSet, err := newStatefulSetForCassandraDatacenter( nil, "default", @@ -550,10 +521,23 @@ func TestReconcilePods(t *testing.T) { assert.NoErrorf(t, err, "error occurred creating statefulset") statefulSet.Status.Replicas = int32(1) + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: getStatefulSetPodNameForIdx(statefulSet, 0), + Namespace: rc.Datacenter.Namespace, + Labels: map[string]string{}, + }, + } + require.NoError(t, rc.Client.Create(rc.Ctx, pod)) + err = rc.ReconcilePods(statefulSet) assert.NoErrorf(t, err, "Should not have returned an error") - mockClient.AssertExpectations(t) + reconciledPod := &corev1.Pod{} + require.NoError(t, rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(pod), reconciledPod)) + expectedLabels := rc.Datacenter.GetRackLabels("default") + expectedLabels[oplabels.ManagedByLabel] = oplabels.ManagedByLabelValue + assert.Equal(t, expectedLabels, reconciledPod.GetLabels()) } func TestReconcilePods_WithVolumes(t *testing.T) { @@ -606,7 +590,7 @@ func TestReconcilePods_WithVolumes(t *testing.T) { pvc, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(pod, pvc).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(pod, pvc).WithRuntimeObjects(trackObjects...).Build() err = rc.ReconcilePods(statefulSet) assert.NoErrorf(t, err, "Should not have returned an error") } @@ -651,16 +635,23 @@ func TestReconcileNextRack_CreateError(t *testing.T) { imageRegistry) assert.NoErrorf(t, err, "error occurred creating statefulset") - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientCreate(mockClient, fmt.Errorf("")) - k8sMockClientUpdate(mockClient, nil).Times(1) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + Create: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if _, ok := obj.(*appsv1.StatefulSet); ok { + return fmt.Errorf("") + } + return c.Create(ctx, obj, opts...) + }, + }). + Build() err = rc.ReconcileNextRack(statefulSet) - mockClient.AssertExpectations(t) - assert.Errorf(t, err, "Should have returned an error while calculating reconciliation actions") } @@ -739,7 +730,7 @@ func TestReconcileRacks(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(desiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -763,10 +754,20 @@ func TestReconcileRacks_GetStatefulsetError(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientGet(mockClient, fmt.Errorf("")) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
+ WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, c client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if _, ok := obj.(*appsv1.StatefulSet); ok { + return fmt.Errorf("") + } + return c.Get(ctx, key, obj, opts...) + }, + }). + Build() var rackInfo []*RackInformation @@ -780,8 +781,6 @@ func TestReconcileRacks_GetStatefulsetError(t *testing.T) { result, err := rc.ReconcileAllRacks() - mockClient.AssertExpectations(t) - assert.Errorf(t, err, "Should have returned an error") t.Skip("FIXME - Skipping assertion") @@ -812,7 +811,7 @@ func TestReconcileRacks_WaitingForReplicas(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(desiredStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -849,7 +848,7 @@ func TestReconcileRacks_NeedMoreReplicas(t *testing.T) { preExistingStatefulSet, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -891,7 +890,7 @@ func TestReconcileRacks_DoesntScaleDown(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -928,7 +927,7 @@ func TestReconcileRacks_NeedToPark(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(preExistingStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -979,7 +978,7 @@ func TestReconcileRacks_AlreadyReconciled(t *testing.T) { desiredPdb, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1056,7 +1055,7 @@ func TestReconcileRacks_FirstRackAlreadyReconciled(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(desiredStatefulSet, secondDesiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, secondDesiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1126,7 +1125,7 @@ func TestReconcileRacks_UpdateRackNodeCount(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(tt.args.statefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = 
fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(tt.args.statefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() if err := rc.UpdateRackNodeCount(tt.args.statefulSet, tt.args.newNodeCount); (err != nil) != tt.wantErr { t.Errorf("updateRackNodeCount() error = %v, wantErr %v", err, tt.wantErr) @@ -1168,7 +1167,7 @@ func TestReconcileRacks_UpdateConfig(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1833,19 +1832,18 @@ func TestCleanupAfterScaling(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - var task *taskapi.CassandraTask - // 1. Create task - return ok - k8sMockClientCreate(rc.Client.(*mocks.Client), nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*taskapi.CassandraTask) - task = arg - }). - Times(1) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() r := rc.cleanupAfterScaling() + taskList := &taskapi.CassandraTaskList{} + require.NoError(t, rc.Client.List(rc.Ctx, taskList)) + require.Len(t, taskList.Items, 1) + task := taskList.Items[0] assert.Equal(result.Continue(), r, "expected result of result.Continue()") assert.Equal(taskapi.CommandCleanup, task.Spec.Jobs[0].Command) assert.Equal(0, len(rc.Datacenter.Status.TrackedTasks)) @@ -1859,36 +1857,26 @@ func TestCleanupAfterScalingWithTracker(t *testing.T) { // Setup annotation - mockClient := mocks.NewClient(t) - rc.Client = mockClient - metav1.SetMetaDataAnnotation(&rc.Datacenter.ObjectMeta, api.TrackCleanupTasksAnnotation, "true") - - var task *taskapi.CassandraTask - // 1. Create task - return ok - k8sMockClientCreate(rc.Client.(*mocks.Client), nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*taskapi.CassandraTask) - task = arg - }). - Times(1) - - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil).Once() + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() r := rc.cleanupAfterScaling() + taskKey := types.NamespacedName{Name: rc.Datacenter.Status.TrackedTasks[0].Name, Namespace: rc.Datacenter.Status.TrackedTasks[0].Namespace} + task := &taskapi.CassandraTask{} + require.NoError(t, rc.Client.Get(rc.Ctx, taskKey, task)) assert.Equal(taskapi.CommandCleanup, task.Spec.Jobs[0].Command) assert.Equal(result.RequeueSoon(10), r, "expected result of result.RequeueSoon(10)") assert.Equal(1, len(rc.Datacenter.Status.TrackedTasks)) - // 3. GET - return completed task - k8sMockClientGet(rc.Client.(*mocks.Client), nil). - Run(func(args mock.Arguments) { - arg := args.Get(2).(*taskapi.CassandraTask) - task.DeepCopyInto(arg) - timeNow := metav1.Now() - arg.Status.CompletionTime = &timeNow - }).Once() - // 4. 
Patch to datacenter status - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil).Once() + + timeNow := metav1.Now() + task.Status.CompletionTime = &timeNow + require.NoError(t, rc.Client.Update(rc.Ctx, task)) + r = rc.cleanupAfterScaling() assert.Equal(result.Continue(), r, "expected result of result.Continue()") assert.Equal(0, len(rc.Datacenter.Status.TrackedTasks)) @@ -1900,21 +1888,20 @@ func TestCleanupAfterScalingWithParallelAnnotation(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient _ = rc.CalculateRackInformation() metav1.SetMetaDataAnnotation(&rc.Datacenter.ObjectMeta, api.EnableParallelCleanupWithinRackAnnotation, "true") - - var task *taskapi.CassandraTask - // 1. Create task - return ok - k8sMockClientCreate(rc.Client.(*mocks.Client), nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*taskapi.CassandraTask) - task = arg - }). - Times(1) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() r := rc.cleanupAfterScaling() + taskList := &taskapi.CassandraTaskList{} + require.NoError(t, rc.Client.List(rc.Ctx, taskList)) + require.Len(t, taskList.Items, 1) + task := taskList.Items[0] assert.Equal(result.Continue(), r, "expected result of result.Continue()") assert.Equal(taskapi.CommandCleanup, task.Spec.Jobs[0].Command) assert.Equal(0, len(rc.Datacenter.Status.TrackedTasks)) @@ -2010,16 +1997,6 @@ func TestFailedStart(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - done := make(chan struct{}) - k8sMockClientDelete(mockClient, nil).Once().Run(func(mock.Arguments) { close(done) }) - - // Patch labelStarting, lastNodeStarted.. - k8sMockClientPatch(mockClient, nil).Once() - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil).Twice() - res := &http.Response{ StatusCode: http.StatusInternalServerError, Body: io.NopCloser(strings.NewReader("OK")), @@ -2034,19 +2011,25 @@ func TestFailedStart(t *testing.T) { Return(res, nil). Once() - client := httphelper.NodeMgmtClient{ + nodeMgmtClient := httphelper.NodeMgmtClient{ Client: mockHttpClient, Log: rc.ReqLogger, Protocol: "http", } - rc.NodeMgmtClient = client + rc.NodeMgmtClient = nodeMgmtClient epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } pod := makeReloadTestPod() + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(runtimeObjectHelper(rc, nil, []*corev1.Pod{pod})...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
+ Build() fakeRecorder := record.NewFakeRecorder(5) rc.Recorder = fakeRecorder @@ -2055,19 +2038,18 @@ func TestFailedStart(t *testing.T) { // The start is async method, so the error is not returned here assert.Nil(t, err) - select { - case <-done: - case <-time.After(2 * time.Second): - assert.Fail(t, "No pod delete occurred") - } - - // mockClient.AssertExpectations(t) - // mockHttpClient.AssertExpectations(t) + require.Eventually(t, func() bool { + currentPod := &corev1.Pod{} + err := rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(pod), currentPod) + return apierrors.IsNotFound(err) + }, 2*time.Second, 20*time.Millisecond, "expected pod delete after failed start") + require.Eventually(t, func() bool { + return len(rc.Datacenter.Status.FailedStarts) == 1 && rc.Datacenter.Status.FailedStarts[0] == pod.Name + }, 2*time.Second, 20*time.Millisecond, "expected failed start status update") close(fakeRecorder.Events) // Should have 2 events, one to indicate Cassandra is starting, one to indicate it failed to start assert.Equal(t, 2, len(fakeRecorder.Events)) - assert.Equal(t, rc.Datacenter.Status.FailedStarts[0], pod.Name) } func TestStartBootstrappedNodes(t *testing.T) { @@ -2263,8 +2245,25 @@ func TestStartBootstrappedNodes(t *testing.T) { } } - mockClient := mocks.NewClient(t) - rc.Client = mockClient + trackObjects := []runtime.Object{ + rc.Datacenter, + // rc.statefulSets, + // rc.dcPods, + } + + for _, sts := range rc.statefulSets { + trackObjects = append(trackObjects, sts) + } + for _, pod := range rc.dcPods { + trackObjects = append(trackObjects, pod) + } + + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(trackObjects...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() expectedStartCount := 0 for i, rackPods := range tt.racks { @@ -2284,12 +2283,6 @@ func TestStartBootstrappedNodes(t *testing.T) { }() if tt.wantNotReady { - // mock the calls in labelServerPodStarting: - // patch the pod: pod.Labels[api.CassNodeState] = stateStarting - k8sMockClientPatch(mockClient, nil).Times(expectedStartCount) - // patch the dc status: dc.Status.LastServerNodeStarted = metav1.Now() - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil).Times(expectedStartCount) - res := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("OK")), @@ -2330,6 +2323,8 @@ func TestStartBootstrappedNodes(t *testing.T) { } } + assertStartingPodsAndStatusPatched(t, rc, expectedStartCount, false) + fakeRecorder := rc.Recorder.(*record.FakeRecorder) close(fakeRecorder.Events) if assert.Lenf(t, fakeRecorder.Events, len(tt.wantEvents), "expected %d events, got %d", len(tt.wantEvents), len(fakeRecorder.Events)) { @@ -2339,8 +2334,6 @@ func TestStartBootstrappedNodes(t *testing.T) { } assert.ElementsMatch(t, tt.wantEvents, gotEvents) } - - mockClient.AssertExpectations(t) }) } } @@ -2603,17 +2596,15 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { } } - mockClient := mocks.NewClient(t) - rc.Client = mockClient + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
+ Build() done := make(chan struct{}) if tt.wantNotReady { - // mock the calls in labelServerPodStarting: - // patch the pod: pod.Labels[api.CassNodeState] = stateStarting - k8sMockClientPatch(mockClient, nil) - // patch the dc status: dc.Status.LastServerNodeStarted = metav1.Now() - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) - res := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("OK")), @@ -2654,6 +2645,8 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { } } + assertStartingPodsAndStatusPatched(t, rc, len(tt.wantEvents), false) + fakeRecorder := rc.Recorder.(*record.FakeRecorder) close(fakeRecorder.Events) if assert.Lenf(t, fakeRecorder.Events, len(tt.wantEvents), "expected %d events, got %d", len(tt.wantEvents), len(fakeRecorder.Events)) { @@ -2663,8 +2656,6 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { } assert.Equal(t, tt.wantEvents, gotEvents) } - - mockClient.AssertExpectations(t) }) } } @@ -2754,18 +2745,15 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { rc.dcPods = append(rc.dcPods, p) } } - - mockClient := mocks.NewClient(t) - rc.Client = mockClient + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() done := make(chan struct{}) if tt.wantNotReady { - // mock the calls in labelServerPodStarting: - // patch the pod: pod.Labels[api.CassNodeState] = stateStarting - k8sMockClientPatch(mockClient, nil) - // patch the dc status: dc.Status.LastServerNodeStarted = metav1.Now() - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) - res := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("OK")), @@ -2781,14 +2769,13 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { Once(). Run(func(mock.Arguments) { close(done) }) - client := httphelper.NodeMgmtClient{ + nodeMgmtClient := httphelper.NodeMgmtClient{ Client: mockHttpClient, Log: rc.ReqLogger, Protocol: "http", } - rc.NodeMgmtClient = client + rc.NodeMgmtClient = nodeMgmtClient } - epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -2797,7 +2784,6 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { assert.NoError(t, err) assert.Equalf(t, tt.wantNotReady, gotNotReady, "expected not ready to be %v", tt.wantNotReady) - if tt.wantNotReady { select { case <-done: @@ -2806,6 +2792,8 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { } } + assertStartingPodsAndStatusPatched(t, rc, len(tt.wantEvents), false) + fakeRecorder := rc.Recorder.(*record.FakeRecorder) close(fakeRecorder.Events) if assert.Lenf(t, fakeRecorder.Events, len(tt.wantEvents), "expected %d events, got %d", len(tt.wantEvents), len(fakeRecorder.Events)) { @@ -2815,8 +2803,6 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { } assert.Equal(t, tt.wantEvents, gotEvents) } - - mockClient.AssertExpectations(t) }) } } @@ -2935,8 +2921,12 @@ func TestStartOneNodePerRack(t *testing.T) { } } - mockClient := mocks.NewClient(t) - rc.Client = mockClient + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). 
+ WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() done := make(chan struct{}) @@ -2962,23 +2952,6 @@ func TestStartOneNodePerRack(t *testing.T) { Protocol: "http", } rc.NodeMgmtClient = client - - // mock the calls in labelServerPodStarting: - // patch the pod: pod.Labels[api.CassNodeState] = stateStarting - k8sMockClientPatch(mockClient, nil) - // get the status client - // patch the dc status: dc.Status.LastServerNodeStarted = metav1.Now() - k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil) - - // We need to mock the hasAdditionalSeeds call - // Mock the Get calls for EndpointSlices - // Three calls for the three potential slices (IPv4, IPv6, FQDN) - - if tt.seedCount < 1 { - // There's additional checks here, for fetching the possible additional-seeds (the GET) and pre-adding a seed label - k8sMockClientGet(mockClient, nil).Times(3) - k8sMockClientPatch(mockClient, nil) - } } epData := httphelper.CassMetadataEndpoints{ @@ -2997,6 +2970,11 @@ func TestStartOneNodePerRack(t *testing.T) { assert.NoError(t, err) assert.Equalf(t, tt.wantNotReady, gotNotReady, "expected not ready to be %v", tt.wantNotReady) + expectedStartingCount := 0 + if tt.wantNotReady { + expectedStartingCount = 1 + } + assertStartingPodsAndStatusPatched(t, rc, expectedStartingCount, tt.wantNotReady && tt.seedCount < 1) }) } } @@ -3060,19 +3038,6 @@ func TestStartOneNodePerRackFailed(t *testing.T) { } } - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - mockHttpClient := mocks.NewHttpClient(t) - k8sMockClientGet(mockClient, nil).Times(3) - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } - rc.NodeMgmtClient = client - epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -3453,18 +3418,19 @@ func TestSetConditionStatus(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientStatusUpdate(mockClient.Status().(*mocks.SubResourceClient), nil).Times(2) assert.NoError(rc.setConditionStatus(api.DatacenterHealthy, corev1.ConditionTrue)) assert.Equal(corev1.ConditionTrue, rc.Datacenter.GetConditionStatus(api.DatacenterHealthy)) + dc := &api.CassandraDatacenter{} + assert.NoError(rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(rc.Datacenter), dc)) + assert.Equal(corev1.ConditionTrue, dc.GetConditionStatus(api.DatacenterHealthy)) val, err := monitoring.GetMetricValue("cass_operator_datacenter_status", map[string]string{"datacenter": rc.Datacenter.DatacenterName(), "condition": string(api.DatacenterHealthy)}) assert.NoError(err) assert.Equal(float64(1), val) assert.NoError(rc.setConditionStatus(api.DatacenterHealthy, corev1.ConditionFalse)) assert.Equal(corev1.ConditionFalse, rc.Datacenter.GetConditionStatus(api.DatacenterHealthy)) + assert.NoError(rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(rc.Datacenter), dc)) + assert.Equal(corev1.ConditionFalse, dc.GetConditionStatus(api.DatacenterHealthy)) val, err = monitoring.GetMetricValue("cass_operator_datacenter_status", map[string]string{"datacenter": rc.Datacenter.DatacenterName(), "condition": string(api.DatacenterHealthy)}) assert.NoError(err) assert.Equal(float64(0), val) @@ -3475,11 +3441,6 @@ func TestDatacenterStatus(t *testing.T) { defer cleanupMockScr() assert := assert.New(t) - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - 
k8sMockClientStatusPatch(mockClient.Status().(*mocks.SubResourceClient), nil).Once() - k8sMockClientStatusUpdate(mockClient.Status().(*mocks.SubResourceClient), nil).Times(2) assert.NoError(rc.setConditionStatus(api.DatacenterRequiresUpdate, corev1.ConditionTrue)) // This uses one StatusUpdate call rc.Datacenter.Status.ObservedGeneration = 0 rc.Datacenter.Generation = 1 @@ -3517,7 +3478,7 @@ func TestDatacenterPods(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3562,7 +3523,7 @@ func TestDatacenterPodsOldLabels(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3616,7 +3577,7 @@ func TestDatacenterPodsNoDualFetch(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3662,7 +3623,7 @@ func TestCheckRackLabels(t *testing.T) { desiredStatefulSet, rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() rc.statefulSets = []*appsv1.StatefulSet{desiredStatefulSet} @@ -3704,7 +3665,7 @@ func TestCheckPodsReadyAllStarted(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = desiredStatefulSet.Labels[api.RackLabel] @@ -4187,7 +4148,7 @@ func TestRefreshSeeds(t *testing.T) { metav1.SetMetaDataLabel(&mp.ObjectMeta, api.SeedNodeLabel, "true") trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -4297,3 +4258,53 @@ func TestRefreshSeeds(t *testing.T) { mockHttpClient.AssertExpectations(t) }) } + +func runtimeObjectHelper(rc *ReconciliationContext, statefulSets []*appsv1.StatefulSet, pods []*corev1.Pod) []runtime.Object { + trackObjects := []runtime.Object{rc.Datacenter} + for _, statefulSet := range statefulSets { + if statefulSet != nil { + trackObjects = append(trackObjects, statefulSet) + } + } + for _, pod := range pods { + 
if pod != nil { + trackObjects = append(trackObjects, pod) + } + } + return trackObjects +} + +func assertStartingPodsAndStatusPatched(t *testing.T, rc *ReconciliationContext, expectedStartingCount int, requireSeedLabel bool) { + t.Helper() + + require.Eventually(t, func() bool { + podList := &corev1.PodList{} + if err := rc.Client.List(rc.Ctx, podList); err != nil { + return false + } + + startingCount := 0 + for _, pod := range podList.Items { + if pod.Labels[api.CassNodeState] != stateStarting { + continue + } + startingCount++ + if requireSeedLabel && pod.Labels[api.SeedNodeLabel] != "true" { + return false + } + } + + dc := &api.CassandraDatacenter{} + if err := rc.Client.Get(rc.Ctx, client.ObjectKeyFromObject(rc.Datacenter), dc); err != nil { + return false + } + + if startingCount != expectedStartingCount { + return false + } + if expectedStartingCount > 0 { + return !dc.Status.LastServerNodeStarted.IsZero() + } + return dc.Status.LastServerNodeStarted.IsZero() + }, 2*time.Second, 20*time.Millisecond, "expected persisted starting pod and datacenter status updates") +} diff --git a/pkg/reconciliation/reconcile_services_test.go b/pkg/reconciliation/reconcile_services_test.go index f74843556..10413538c 100644 --- a/pkg/reconciliation/reconcile_services_test.go +++ b/pkg/reconciliation/reconcile_services_test.go @@ -6,17 +6,16 @@ package reconciliation import ( "context" "fmt" - "reflect" + "maps" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" - "github.com/k8ssandra/cass-operator/pkg/mocks" "github.com/k8ssandra/cass-operator/pkg/utils" discoveryv1 "k8s.io/api/discovery/v1" ) @@ -26,14 +25,6 @@ func TestReconcileHeadlessService(t *testing.T) { defer cleanupMockScr() recResult := rc.CheckHeadlessServices() - - // kind of weird to check this path we don't want in a test, but - // it's useful to see what the error is - if recResult.Completed() { - _, err := recResult.Output() - assert.NoErrorf(t, err, "Should not have returned an error") - } - assert.False(t, recResult.Completed(), "Reconcile loop should not be completed") } @@ -41,111 +32,49 @@ func TestReconcileHeadlessService_UpdateLabelsAndAnnotations(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - // place holder for service label maps - svcLabelMap := make(map[string]map[string]string) - // place holder for service annotation maps - svcAnnotationMap := make(map[string]map[string]string) - - k8sMockClientGet(mockClient, nil). - Times(4) - k8sMockClientUpdate(mockClient, nil). - Run(func(args mock.Arguments) { - arg := args.Get(1).(*corev1.Service) - // store service labels - svcLabelMap[arg.GetName()] = arg.GetLabels() - // store svc annotations - svcAnnotationMap[arg.GetName()] = arg.GetAnnotations() - }). 
- Times(4) - - // Check the service should populate labels and annotations recResult := rc.CheckHeadlessServices() - - // kind of weird to check this path we don't want in a test, but - // it's useful to see what the error is - if recResult.Completed() { - _, err := recResult.Output() - assert.NoErrorf(t, err, "Should not have returned an error") - } - assert.False(t, recResult.Completed(), "Reconcile loop should not be completed") - // Mock the Datacenter Service to have additional labels dcSvcName := rc.Datacenter.GetDatacenterServiceName() - assert.Containsf(t, svcLabelMap, dcSvcName, "Expected Datacenter service to be in service map. Expected name: %s, service map:\n%v\n", dcSvcName, svcLabelMap) - dcSvcLabels := svcLabelMap[dcSvcName] + dcSvc := &corev1.Service{} + err := rc.Client.Get(rc.Ctx, types.NamespacedName{Name: dcSvcName, Namespace: rc.Datacenter.Namespace}, dcSvc) + assert.NoError(t, err) + + dcSvcLabels := dcSvc.GetLabels() dcSvcLabels["AddKey1"] = "Value1" dcSvcLabels["AddKey2"] = "Value2" - dcSvcAnnotations := svcAnnotationMap[dcSvcName] + dcSvc.SetLabels(dcSvcLabels) + + dcSvcAnnotations := dcSvc.GetAnnotations() dcSvcAnnotations["AddAnnotation1"] = "AddValue1" dcSvcAnnotations["AddAnnotation2"] = "AddValue2" - // In DC Additional Labels, add a label, change a label value and delete a label + dcSvc.SetAnnotations(dcSvcAnnotations) + assert.NoError(t, rc.Client.Update(rc.Ctx, dcSvc)) + rc.Datacenter.Spec.AdditionalServiceConfig.DatacenterService.Labels = map[string]string{"AddKey1": "ChangeValue1", "AddKey3": "Value3"} - updatedDcSvcLabels := make(map[string]string) - // copy current labels into updated labels - for k, v := range dcSvcLabels { - updatedDcSvcLabels[k] = v - } + updatedDcSvcLabels := maps.Clone(dcSvcLabels) delete(updatedDcSvcLabels, "AddKey2") updatedDcSvcLabels["AddKey1"] = "ChangeValue1" updatedDcSvcLabels["AddKey3"] = "Value3" - // In DC Additional Annotations, add an annotation, change an annotation value and delete an annotation + rc.Datacenter.Spec.AdditionalServiceConfig.DatacenterService.Annotations = map[string]string{"AddAnnotation1": "ChangeAnnotation1", "AddAnnotation3": "AddValue3"} - updatedDcSvcAnnotations := make(map[string]string) - // copy current annotations into updated annotations - for k, v := range dcSvcAnnotations { - updatedDcSvcAnnotations[k] = v - } + updatedDcSvcAnnotations := maps.Clone(dcSvcAnnotations) delete(updatedDcSvcAnnotations, "AddAnnotation2") updatedDcSvcAnnotations["AddAnnotation1"] = "ChangeAnnotation1" updatedDcSvcAnnotations["AddAnnotation3"] = "AddValue3" - // resource hash annotation will change, so exclude it from the comparison delete(updatedDcSvcAnnotations, utils.ResourceHashAnnotationKey) - k8sMockClientGet(mockClient, nil). - Run(func(args mock.Arguments) { - svcName := args.Get(1).(types.NamespacedName) - arg := args.Get(2).(*corev1.Service) - if svcName.Name == dcSvcName { - // set the expected service labels - arg.SetLabels(dcSvcLabels) - // set the expected service annotations - arg.SetAnnotations(dcSvcAnnotations) - } - }). - Times(4) - k8sMockClientUpdate(mockClient, nil). 
- Run(func(args mock.Arguments) { - arg := args.Get(1).(*corev1.Service) - // store service labels - svcLabelMap[arg.GetName()] = arg.GetLabels() - // store service annotations - svcAnnotationMap[arg.GetName()] = arg.GetAnnotations() - // verify additional labels and annotations are added for the Datacenter Service - if arg.GetName() == dcSvcName { - assert.Truef(t, reflect.DeepEqual(arg.GetLabels(), updatedDcSvcLabels), "Datacenter Service Labels do not match. Expected:\n%v\nObserved:\n%v\n", updatedDcSvcLabels, arg.GetLabels()) - // resource hash annotation will change, so exclude it from the comparison - observedAnnotations := arg.GetAnnotations() - delete(observedAnnotations, utils.ResourceHashAnnotationKey) - assert.Truef(t, reflect.DeepEqual(arg.GetAnnotations(), updatedDcSvcAnnotations), "Datacenter Service Annotations do not match. Expected:\n%v\nObserved:\n%v\n", updatedDcSvcAnnotations, observedAnnotations) - } - }). - Times(4) - - // re-populate labels and annotations recResult = rc.CheckHeadlessServices() + assert.False(t, recResult.Completed(), "Reconcile loop should not be completed") - // kind of weird to check this path we don't want in a test, but - // it's useful to see what the error is - if recResult.Completed() { - _, err := recResult.Output() - assert.NoErrorf(t, err, "Should not have returned an error") - } + updatedSvc := &corev1.Service{} + err = rc.Client.Get(rc.Ctx, types.NamespacedName{Name: dcSvcName, Namespace: rc.Datacenter.Namespace}, updatedSvc) + assert.NoError(t, err) + assert.Equal(t, updatedDcSvcLabels, updatedSvc.GetLabels()) - assert.False(t, recResult.Completed(), "Reconcile loop should not be completed") + observedAnnotations := updatedSvc.GetAnnotations() + delete(observedAnnotations, utils.ResourceHashAnnotationKey) + assert.Equal(t, updatedDcSvcAnnotations, observedAnnotations) } func TestCreateHeadlessService(t *testing.T) { @@ -167,38 +96,38 @@ func TestCreateHeadlessService(t *testing.T) { } func TestCreateHeadlessService_ClientReturnsError(t *testing.T) { - // skipped because mocking Status() call and response is very tricky - t.Skip() rc, svc, cleanupMockScr := setupTest() defer cleanupMockScr() - mockClient := mocks.NewClient(t) - rc.Client = mockClient - - k8sMockClientCreate(mockClient, fmt.Errorf("")) - k8sMockClientUpdate(mockClient, nil).Times(1) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(rc.Datacenter). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + WithInterceptorFuncs(interceptor.Funcs{ + Create: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if _, ok := obj.(*corev1.Service); ok { + return fmt.Errorf("") + } + return c.Create(ctx, obj, opts...) + }, + }). 
+ Build() rc.Services = []*corev1.Service{svc} recResult := rc.CreateHeadlessServices() - // kind of weird to check this path we don't want in a test, but - // it's useful to see what the error is - if recResult.Completed() { - _, err := recResult.Output() - assert.NoErrorf(t, err, "Should not have returned an error") - } - assert.True(t, recResult.Completed(), "Reconcile loop should be completed") - - mockClient.AssertExpectations(t) + _, err := recResult.Output() + assert.Error(t, err, "Should have returned the service creation error") } func TestEndpointSliceControllerIntegration(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - fakeClient := fake.NewClientBuilder().WithRuntimeObjects(rc.Datacenter).Build() + fakeClient := fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithRuntimeObjects(rc.Datacenter).Build() rc.Client = fakeClient rc.Datacenter.Spec.AdditionalSeeds = []string{ diff --git a/pkg/reconciliation/testing.go b/pkg/reconciliation/testing.go index b91010bcf..4bdfa6dcf 100644 --- a/pkg/reconciliation/testing.go +++ b/pkg/reconciliation/testing.go @@ -24,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" - "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -34,12 +33,16 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" + taskapi "github.com/k8ssandra/cass-operator/apis/control/v1alpha1" "github.com/k8ssandra/cass-operator/pkg/httphelper" "github.com/k8ssandra/cass-operator/pkg/images" "github.com/k8ssandra/cass-operator/pkg/mocks" discoveryv1 "k8s.io/api/discovery/v1" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" ) +const podPVCClaimNameField = "spec.volumes.persistentVolumeClaim.claimName" + func newTestImageRegistry() images.ImageRegistry { imageConfigFile := filepath.Join("..", "..", "tests", "testdata", "image_config_parsing.yaml") registry, err := images.NewImageRegistry(imageConfigFile) @@ -121,11 +124,13 @@ func CreateMockReconciliationContext( storageClass, } - s := scheme.Scheme - setupScheme(s) - // s.AddKnownTypes(api.GroupVersion, cassandraDatacenter) - - fakeClient := fake.NewClientBuilder().WithStatusSubresource(cassandraDatacenter).WithRuntimeObjects(trackObjects...).Build() + s := setupScheme(runtime.NewScheme()) + fakeClient := fake.NewClientBuilder(). + WithScheme(s). + WithStatusSubresource(cassandraDatacenter). + WithRuntimeObjects(trackObjects...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() request := &reconcile.Request{ NamespacedName: types.NamespacedName{ @@ -174,158 +179,29 @@ func setupTest() (*ReconciliationContext, *corev1.Service, func()) { return rc, service, cleanupMockScr } -func k8sMockClientGet(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("Get", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(key client.ObjectKey) bool { - return key != client.ObjectKey{} - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(returnArg). 
- Once() -} - -func k8sMockClientUpdate(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("Update", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientPatch(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("Patch", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - }), - mock.MatchedBy( - func(patch client.Patch) bool { - return patch != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientStatusPatch(mockClient *mocks.SubResourceClient, returnArg interface{}) *mock.Call { - return mockClient.On("Patch", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - }), - mock.MatchedBy( - func(patch client.Patch) bool { - return patch != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientStatusUpdate(mockClient *mocks.SubResourceClient, returnArg interface{}) *mock.Call { - return mockClient.On("Update", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientCreate(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("Create", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientDelete(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("Delete", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - })). - Return(returnArg). - Once() -} - -func k8sMockClientList(mockClient *mocks.Client, returnArg interface{}) *mock.Call { - return mockClient.On("List", - mock.MatchedBy( - func(ctx context.Context) bool { - return ctx != nil - }), - mock.MatchedBy( - func(obj runtime.Object) bool { - return obj != nil - }), - mock.MatchedBy(matchListOptionsArg)). - Return(returnArg). 
- Once() -} - -func matchListOptionsArg(arg interface{}) bool { - return listOptionsFromArg(arg) != nil -} - -func listOptionsFromArg(arg interface{}) *client.ListOptions { - switch v := arg.(type) { - case *client.ListOptions: - return v - case []client.ListOption: - opts := &client.ListOptions{} - for _, opt := range v { - if opt != nil { - opt.ApplyToList(opts) - } - } - return opts - default: - return nil - } -} - func setupScheme(scheme *runtime.Scheme) *runtime.Scheme { if scheme == nil { scheme = runtime.NewScheme() } + _ = clientgoscheme.AddToScheme(scheme) _ = api.AddToScheme(scheme) + _ = taskapi.AddToScheme(scheme) _ = corev1.AddToScheme(scheme) _ = discoveryv1.AddToScheme(scheme) return scheme } + +func podPVCClaimNames(obj client.Object) []string { + pod, ok := obj.(*corev1.Pod) + if !ok { + return nil + } + + var claimNames []string + for _, volume := range pod.Spec.Volumes { + if volume.PersistentVolumeClaim != nil && volume.PersistentVolumeClaim.ClaimName != "" { + claimNames = append(claimNames, volume.PersistentVolumeClaim.ClaimName) + } + } + return claimNames +} From 98d569311e458dab34b5ff290e18ccff907cc91c Mon Sep 17 00:00:00 2001 From: Michael Burman Date: Mon, 13 Apr 2026 19:43:54 +0300 Subject: [PATCH 2/5] Replace httpClient mocks with httpserver --- pkg/reconciliation/decommission_node_test.go | 74 +-- pkg/reconciliation/reconcile_fql.go | 6 +- pkg/reconciliation/reconcile_fql_test.go | 249 +++----- pkg/reconciliation/reconcile_racks_test.go | 562 +++++++++---------- pkg/reconciliation/testing.go | 168 +++++- 5 files changed, 535 insertions(+), 524 deletions(-) diff --git a/pkg/reconciliation/decommission_node_test.go b/pkg/reconciliation/decommission_node_test.go index 7bae490f6..9cdb80dc5 100644 --- a/pkg/reconciliation/decommission_node_test.go +++ b/pkg/reconciliation/decommission_node_test.go @@ -4,17 +4,13 @@ package reconciliation import ( - "io" "net/http" - "strings" "sync" "testing" api "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" "github.com/k8ssandra/cass-operator/internal/result" "github.com/k8ssandra/cass-operator/pkg/httphelper" - "github.com/k8ssandra/cass-operator/pkg/mocks" - "github.com/stretchr/testify/mock" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -24,58 +20,43 @@ func TestRetryDecommissionNode(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() state := "UP" - podIP := "192.168.101.11" rc.Datacenter.SetCondition(api.DatacenterCondition{ Status: corev1.ConditionTrue, Type: api.DatacenterScalingDown, }) - res := &http.Response{ - StatusCode: http.StatusBadRequest, - Body: io.NopCloser(strings.NewReader("OK")), - } wg := &sync.WaitGroup{} wg.Add(1) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/node/decommission" - })). - Return(res, nil). - Once(). - Run(func(args mock.Arguments) { wg.Done() }) - - resFeatureSet := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(strings.NewReader("")), - } - - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/metadata/versions/features" - })). - Return(resFeatureSet, nil). 
- Once() - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RequestURI() { + case "/api/v0/metadata/versions/features": + http.NotFound(w, r) + case "/api/v0/ops/node/decommission?force=true": + w.WriteHeader(http.StatusBadRequest) + wg.Done() + default: + http.NotFound(w, r) + } + })) + rc.NodeMgmtClient = server.client(rc.ReqLogger) labels := make(map[string]string) labels[api.CassNodeState] = stateDecommissioning - rc.dcPods = []*corev1.Pod{{ + pod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: "pod-1", Labels: labels, }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "cassandra", + }, + }, + }, Status: corev1.PodStatus{ - PodIP: podIP, ContainerStatuses: []corev1.ContainerStatus{ { Name: "cassandra", @@ -83,12 +64,14 @@ func TestRetryDecommissionNode(t *testing.T) { }, }, }, - }} + } + server.attachToPod(t, pod) + rc.dcPods = []*corev1.Pod{pod} epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{ { - RpcAddress: podIP, + RpcAddress: pod.Status.PodIP, Status: state, }, }, @@ -98,12 +81,13 @@ func TestRetryDecommissionNode(t *testing.T) { t.Fatalf("expected result of result.RequeueSoon(5) but got %s", r) } wg.Wait() + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/decommission", 1) } func TestRemoveResourcesWhenDone(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - podIP := "192.168.101.11" state := "LEFT" rc.Datacenter.SetCondition(api.DatacenterCondition{ @@ -119,9 +103,7 @@ func TestRemoveResourcesWhenDone(t *testing.T) { Name: "pod-1", Labels: labels, }, - Status: corev1.PodStatus{ - PodIP: podIP, - }, + Status: corev1.PodStatus{}, }} makeInt := func(i int32) *int32 { @@ -141,7 +123,7 @@ func TestRemoveResourcesWhenDone(t *testing.T) { epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{ { - RpcAddress: podIP, + RpcAddress: rc.dcPods[0].Status.PodIP, Status: state, }, }, diff --git a/pkg/reconciliation/reconcile_fql.go b/pkg/reconciliation/reconcile_fql.go index f6e20ad1e..6d7d012d1 100644 --- a/pkg/reconciliation/reconcile_fql.go +++ b/pkg/reconciliation/reconcile_fql.go @@ -20,11 +20,7 @@ func (rc *ReconciliationContext) CheckFullQueryLogging() result.ReconcileResult return result.Error(err) } - podList, err := rc.listPods(rc.Datacenter.GetClusterLabels()) - if err != nil { - rc.ReqLogger.Error(err, "error listing all pods in the cluster to progress full query logging reconciliation") - return result.RequeueSoon(2) - } + podList := rc.clusterPods for _, podPtr := range podList { features, err := rc.NodeMgmtClient.FeatureSet(podPtr) if err != nil { diff --git a/pkg/reconciliation/reconcile_fql_test.go b/pkg/reconciliation/reconcile_fql_test.go index cb4fb45e0..b0e1229ef 100644 --- a/pkg/reconciliation/reconcile_fql_test.go +++ b/pkg/reconciliation/reconcile_fql_test.go @@ -2,16 +2,12 @@ package reconciliation import ( "encoding/json" - "io" "net/http" - "strings" "testing" "github.com/k8ssandra/cass-operator/apis/cassandra/v1beta1" "github.com/k8ssandra/cass-operator/internal/result" - "github.com/k8ssandra/cass-operator/pkg/httphelper" - "github.com/k8ssandra/cass-operator/pkg/mocks" - "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -38,8 +34,7 @@ var ( httpResponseFullQueryDisabled string = `{"entity": false}` ) -func setupPodList(t *testing.T, rc *ReconciliationContext) { - t.Helper() +func setupPodList(rc *ReconciliationContext) { podIP := "192.168.101.11" pods := []corev1.Pod{{ @@ -48,11 +43,19 @@ func setupPodList(t *testing.T, rc *ReconciliationContext) { Namespace: rc.Datacenter.Namespace, Labels: rc.Datacenter.GetClusterLabels(), }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "cassandra", + }}, + }, Status: corev1.PodStatus{ PodIP: podIP, }, }} + rc.clusterPods = []*corev1.Pod{ + &pods[0], + } rc.dcPods = []*corev1.Pod{ &pods[0], } @@ -65,231 +68,161 @@ func setupPodList(t *testing.T, rc *ReconciliationContext) { Build() } -func mockFeaturesEnabled(mockHttpClient *mocks.HttpClient) { - resFeatureSet := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(fullQueryIsSupported)), - } - - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/metadata/versions/features" - })). - Return(resFeatureSet, nil). - Once() -} - -func mockFeaturesNotAvailable(mockHttpClient *mocks.HttpClient) { - resFeatureSet := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(strings.NewReader("")), - } - - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/metadata/versions/features" - })). - Return(resFeatureSet, nil). - Once() -} - -func mockFullQueryLoggingRequestToTrue(mockHttpClient *mocks.HttpClient) { - resFullQueryStatus := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(httpResponseFullQueryEnabled)), - } - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/node/fullquerylogging" - })). - Return(resFullQueryStatus, nil). - Once() -} - -func mockFullQueryLoggingRequestToFalse(mockHttpClient *mocks.HttpClient) { - resFullQueryStatus := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(httpResponseFullQueryDisabled)), - } - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/node/fullquerylogging" - })). - Return(resFullQueryStatus, nil). 
- Once() +func fqlFakeMgmtServer(t *testing.T, supportFQL bool, currentEnabled bool) *fakeMgmtApiServer { + return newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RequestURI() { + case "/api/v0/metadata/versions/features": + if !supportFQL { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(fullQueryIsSupported)) + case "/api/v0/ops/node/fullquerylogging": + w.Header().Set("Content-Type", "application/json") + if currentEnabled { + _, _ = w.Write([]byte(httpResponseFullQueryEnabled)) + } else { + _, _ = w.Write([]byte(httpResponseFullQueryDisabled)) + } + case "/api/v0/ops/node/fullquerylogging?enabled=true", "/api/v0/ops/node/fullquerylogging?enabled=false": + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + })) } -func setupTestEnv(t *testing.T) (*ReconciliationContext, func()) { +func setupTestEnv() (*ReconciliationContext, func()) { rc, _, cleanupMockScr := setupTest() - setupPodList(t, rc) + setupPodList(rc) rc.Datacenter.Spec.ServerType = "cassandra" rc.Datacenter.Spec.ServerVersion = "4.0.1" return rc, cleanupMockScr } +func attachPods(t *testing.T, rc *ReconciliationContext, server *fakeMgmtApiServer) { + server.attachToPod(t, rc.dcPods[0]) + server.attachToPod(t, rc.clusterPods[0]) + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter, rc.dcPods[0]). + WithRuntimeObjects(rc.Datacenter, rc.dcPods[0]). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() +} + func TestCheckFullQueryLoggingNoChangeEnabled(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to support FQL - mockFeaturesEnabled(mockHttpClient) - - // Mock fullQueryLogging to return true - mockFullQueryLoggingRequestToTrue(mockHttpClient) + server := fqlFakeMgmtServer(t, true, true) + for _, pod := range rc.dcPods { + server.attachToPod(t, pod) + } + rc.NodeMgmtClient = server.client(rc.ReqLogger) // Enable FQL config in the Datacenter rc.Datacenter.Spec.Config = json.RawMessage(fqlEnabledConfig) - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) r := rc.CheckFullQueryLogging() if r != result.Continue() { t.Fatalf("expected result of result.Continue() but got %s", r) } + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging?enabled=true", 0) } func TestCheckFullQueryLoggingNoChangeDisabled(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to support FQL - mockFeaturesEnabled(mockHttpClient) - - // Mock fullQueryLogging to return true - mockFullQueryLoggingRequestToFalse(mockHttpClient) + server := fqlFakeMgmtServer(t, true, false) // Don't enable FQL config in the Datacenter // rc.Datacenter.Spec.Config = json.RawMessage(fqlDisabledConfig) - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) r := 
rc.CheckFullQueryLogging() if r != result.Continue() { t.Fatalf("expected result of result.Continue() but got %s", r) } + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging?enabled=false", 0) } func TestCheckFullQueryNotSupported(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to not support FQL - mockFeaturesNotAvailable(mockHttpClient) - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + server := fqlFakeMgmtServer(t, false, false) + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) r := rc.CheckFullQueryLogging() if r != result.Continue() { t.Fatalf("expected result of result.Continue() but got %s", r) } + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 0) } func TestCheckFullQueryLoggingChangeToEnabled(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to support FQL - mockFeaturesEnabled(mockHttpClient) - - // Mock fullQueryLogging to return false - mockFullQueryLoggingRequestToFalse(mockHttpClient) - - // Mock fullQueryLogging change request - mockFullQueryLoggingRequestToTrue(mockHttpClient) + server := fqlFakeMgmtServer(t, true, false) // Enable FQL config in the Datacenter rc.Datacenter.Spec.Config = json.RawMessage(fqlEnabledConfig) - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) r := rc.CheckFullQueryLogging() if r != result.Continue() { t.Fatalf("expected result of result.Continue() but got %s", r) } + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 2) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging?enabled=true", 1) } func TestCheckFullQueryLoggingChangeToDisabled(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to support FQL - mockFeaturesEnabled(mockHttpClient) - - // Mock fullQueryLogging to return true - mockFullQueryLoggingRequestToTrue(mockHttpClient) - - // Mock fullQueryLogging change request to false - mockFullQueryLoggingRequestToFalse(mockHttpClient) + server := fqlFakeMgmtServer(t, true, true) // Keep FQL config disabled in the Datacenter - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) r := rc.CheckFullQueryLogging() if r != result.Continue() { t.Fatalf("expected result of result.Continue() but got %s", r) } + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 2) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging?enabled=false", 1) } func TestCheckFullQueryNotSupportedTriedToUse(t *testing.T) { - rc, cleanupMockScr := setupTestEnv(t) + 
rc, cleanupMockScr := setupTestEnv() defer cleanupMockScr() - - mockHttpClient := mocks.NewHttpClient(t) - - // Mock features request to not support FQL - mockFeaturesNotAvailable(mockHttpClient) + server := fqlFakeMgmtServer(t, false, false) // Enable FQL config in the Datacenter rc.Datacenter.Spec.Config = json.RawMessage(fqlEnabledConfig) - - rc.NodeMgmtClient = httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + attachPods(t, rc, server) + rc.NodeMgmtClient = server.client(rc.ReqLogger) // The error is thrown in handler, but this test bypasses the validation - that's why we take Continue // as correct result. r := rc.CheckFullQueryLogging() _, err := r.Output() - if err == nil { - t.Fatalf("expected result of result.Error() but got %s", r) - } + require.Error(t, err) + server.assertCallCount(t, "/api/v0/metadata/versions/features", 1) + server.assertCallCount(t, "/api/v0/ops/node/fullquerylogging", 0) } func TestNotSupportedVersion(t *testing.T) { diff --git a/pkg/reconciliation/reconcile_racks_test.go b/pkg/reconciliation/reconcile_racks_test.go index 2ce18ab80..7e6c3e5ec 100644 --- a/pkg/reconciliation/reconcile_racks_test.go +++ b/pkg/reconciliation/reconcile_racks_test.go @@ -4,7 +4,6 @@ package reconciliation import ( - "bytes" "context" "fmt" "io" @@ -24,12 +23,10 @@ import ( taskapi "github.com/k8ssandra/cass-operator/apis/control/v1alpha1" "github.com/k8ssandra/cass-operator/internal/result" "github.com/k8ssandra/cass-operator/pkg/httphelper" - "github.com/k8ssandra/cass-operator/pkg/mocks" "github.com/k8ssandra/cass-operator/pkg/monitoring" "github.com/k8ssandra/cass-operator/pkg/oplabels" "github.com/k8ssandra/cass-operator/pkg/utils" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -236,12 +233,27 @@ func TestReconcileRacks_ReconcilePods(t *testing.T) { } mockPods := mockReadyPodsForStatefulSet(desiredStatefulSet, rc.Datacenter.Spec.ClusterName, rc.Datacenter.Name) + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v0/metadata/endpoints": + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"entity":[]}`) + case "/api/v0/ops/seeds/reload", "/api/v0/probes/cluster": + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + default: + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + } + })) for idx := range mockPods { mp := mockPods[idx] + server.attachToPod(t, mp) trackObjects = append(trackObjects, mp) } rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.NodeMgmtClient = server.client(rc.ReqLogger) nextRack := &RackInformation{} nextRack.RackName = desiredStatefulSet.Labels[api.RackLabel] @@ -1234,6 +1246,9 @@ func mockReadyPodsForStatefulSet(sts *appsv1.StatefulSet, cluster, dc string) [] Name: "cassandra", Ready: true, }} + pod.Spec.Containers = []corev1.Container{{ + Name: "cassandra", + }} pod.Status.PodIP = fmt.Sprintf("192.168.1.%d", i) pods = append(pods, pod) } @@ -1732,6 +1747,13 @@ func makeReloadTestPod() *corev1.Pod { api.DatacenterLabel: "mydc", }, }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "cassandra", + }, + }, + }, Status: corev1.PodStatus{ PodIP: "127.0.0.1", }, @@ -1743,84 +1765,57 @@ func Test_callPodEndpoint(t *testing.T) { rc, 
_, cleanupMockScr := setupTest() defer cleanupMockScr() - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/ops/seeds/reload" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + })) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Once() - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } + client := server.client(rc.ReqLogger) pod := makeReloadTestPod() - pod.Status.PodIP = "1.2.3.4" + server.attachToPod(t, pod) if err := client.CallReloadSeedsEndpoint(pod); err != nil { assert.Fail(t, "Should not have returned error") } + server.assertCallCount(t, "/api/v0/ops/seeds/reload", 1) } func Test_callPodEndpoint_BadStatus(t *testing.T) { - res := &http.Response{ - StatusCode: http.StatusBadRequest, - Body: io.NopCloser(strings.NewReader("OK")), - } + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/ops/seeds/reload" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "OK") + })) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/seeds/reload" && req.Method == "POST" - })). - Return(res, nil). - Once() - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: zap.New(), - Protocol: "http", - } + client := server.client(zap.New()) pod := makeReloadTestPod() + server.attachToPod(t, pod) if err := client.CallReloadSeedsEndpoint(pod); err == nil { assert.Fail(t, "Should have returned error") } + server.assertCallCount(t, "/api/v0/ops/seeds/reload", 1) } func Test_callPodEndpoint_RequestFail(t *testing.T) { - res := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(strings.NewReader("OK")), - } + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, fmt.Errorf("")). - Once() - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: zap.New(), - Protocol: "http", - } + client := server.client(zap.New()) pod := makeReloadTestPod() + server.attachToPod(t, pod) + server.Close() if err := client.CallReloadSeedsEndpoint(pod); err == nil { assert.Fail(t, "Should have returned error") @@ -1913,24 +1908,17 @@ func TestStripPassword(t *testing.T) { defer cleanupMockScr() password := "secretPassword" - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(nil, errors.New(password)). 
- Once() - + requestURIs := make([]string, 0, 1) client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, + Client: httpClientDoFunc(func(req *http.Request) (*http.Response, error) { + requestURIs = append(requestURIs, req.URL.RequestURI()) + return nil, errors.New(req.URL.RequestURI()) + }), Log: rc.ReqLogger, Protocol: "http", } pod := makeReloadTestPod() - pod.Status.PodIP = "1.2.3.4" err := client.CallCreateRoleEndpoint(pod, "userNameA", password, true) if err == nil { @@ -1938,6 +1926,8 @@ func TestStripPassword(t *testing.T) { } assert.False(t, strings.Contains(err.Error(), password)) + require.Len(t, requestURIs, 1) + assert.Contains(t, requestURIs[0], "/api/v0/ops/auth/role") } func TestNodereplacements(t *testing.T) { @@ -1997,33 +1987,22 @@ func TestFailedStart(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - res := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Once() - - nodeMgmtClient := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } - - rc.NodeMgmtClient = nodeMgmtClient + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/lifecycle/start" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "OK") + })) + rc.NodeMgmtClient = server.client(rc.ReqLogger) epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } pod := makeReloadTestPod() + server.attachToPod(t, pod) rc.Client = fake.NewClientBuilder(). WithScheme(setupScheme(nil)). WithStatusSubresource(rc.Datacenter). @@ -2050,6 +2029,7 @@ func TestFailedStart(t *testing.T) { close(fakeRecorder.Events) // Should have 2 events, one to indicate Cassandra is starting, one to indicate it failed to start assert.Equal(t, 2, len(fakeRecorder.Events)) + server.assertCallCount(t, "/api/v0/lifecycle/start", 1) } func TestStartBootstrappedNodes(t *testing.T) { @@ -2224,6 +2204,11 @@ func TestStartBootstrappedNodes(t *testing.T) { p := &corev1.Pod{} p.Name = getStatefulSetPodNameForIdx(sts, int32(i)) p.Labels = map[string]string{} + p.Spec.Containers = []corev1.Container{ + { + Name: "cassandra", + }, + } p.Status.ContainerStatuses = []corev1.ContainerStatus{ { Name: "cassandra", @@ -2258,13 +2243,6 @@ func TestStartBootstrappedNodes(t *testing.T) { trackObjects = append(trackObjects, pod) } - rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). - WithStatusSubresource(rc.Datacenter). - WithRuntimeObjects(trackObjects...). - WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). - Build() - expectedStartCount := 0 for i, rackPods := range tt.racks { for j, started := range rackPods { @@ -2283,29 +2261,29 @@ func TestStartBootstrappedNodes(t *testing.T) { }() if tt.wantNotReady { - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Times(expectedStartCount). 
- Run(func(mock.Arguments) { wg.Done() }) - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/lifecycle/start" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + wg.Done() + })) + for _, pod := range rc.dcPods { + server.attachToPod(t, pod) } - rc.NodeMgmtClient = client + rc.NodeMgmtClient = server.client(rc.ReqLogger) + defer server.assertCallCount(t, "/api/v0/lifecycle/start", expectedStartCount) } + rc.Client = fake.NewClientBuilder(). + WithScheme(setupScheme(nil)). + WithStatusSubresource(rc.Datacenter). + WithRuntimeObjects(trackObjects...). + WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). + Build() + epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -2575,6 +2553,11 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { p := &corev1.Pod{} p.Name = getStatefulSetPodNameForIdx(sts, int32(i)) p.Labels = map[string]string{} + p.Spec.Containers = []corev1.Container{ + { + Name: "cassandra", + }, + } p.Status.ContainerStatuses = []corev1.ContainerStatus{ { Name: "cassandra", @@ -2596,6 +2579,23 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { } } + done := make(chan struct{}) + if tt.wantNotReady { + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/lifecycle/start" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + close(done) + })) + for _, pod := range rc.dcPods { + server.attachToPod(t, pod) + } + rc.NodeMgmtClient = server.client(rc.ReqLogger) + defer server.assertCallCount(t, "/api/v0/lifecycle/start", 1) + } rc.Client = fake.NewClientBuilder(). WithScheme(setupScheme(nil)). WithStatusSubresource(rc.Datacenter). @@ -2603,31 +2603,6 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). Build() - done := make(chan struct{}) - if tt.wantNotReady { - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Once(). - Run(func(mock.Arguments) { close(done) }) - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } - rc.NodeMgmtClient = client - } - epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -2754,27 +2729,20 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { done := make(chan struct{}) if tt.wantNotReady { - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Once(). 
- Run(func(mock.Arguments) { close(done) }) - - nodeMgmtClient := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/lifecycle/start" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + close(done) + })) + for _, pod := range rc.dcPods { + server.attachToPod(t, pod) } - rc.NodeMgmtClient = nodeMgmtClient + rc.NodeMgmtClient = server.client(rc.ReqLogger) + defer server.assertCallCount(t, "/api/v0/lifecycle/start", 1) } epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, @@ -2894,6 +2862,11 @@ func TestStartOneNodePerRack(t *testing.T) { p := &corev1.Pod{} p.Name = getStatefulSetPodNameForIdx(sts, int32(i)) p.Labels = map[string]string{} + p.Spec.Containers = []corev1.Container{ + { + Name: "cassandra", + }, + } readyToStart := true if tt.notReadyRacks[rackName] != nil && tt.notReadyRacks[rackName][i] { readyToStart = false @@ -2921,6 +2894,23 @@ func TestStartOneNodePerRack(t *testing.T) { } } + done := make(chan struct{}) + var server *fakeMgmtApiServer + if tt.wantNotReady { + server = newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() != "/api/v0/lifecycle/start" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + close(done) + })) + for _, pod := range rc.dcPods { + server.attachToPod(t, pod) + } + rc.NodeMgmtClient = server.client(rc.ReqLogger) + } rc.Client = fake.NewClientBuilder(). WithScheme(setupScheme(nil)). WithStatusSubresource(rc.Datacenter). @@ -2928,32 +2918,6 @@ func TestStartOneNodePerRack(t *testing.T) { WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). Build() - done := make(chan struct{}) - - if tt.wantNotReady { - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Once(). 
- Run(func(args mock.Arguments) { close(done) }) - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } - rc.NodeMgmtClient = client - } - epData := httphelper.CassMetadataEndpoints{ Entity: []httphelper.EndpointState{}, } @@ -3659,8 +3623,18 @@ func TestCheckPodsReadyAllStarted(t *testing.T) { } mockPods := mockReadyPodsForStatefulSet(desiredStatefulSet, rc.Datacenter.Spec.ClusterName, rc.Datacenter.Name) + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v0/ops/seeds/reload", "/api/v0/probes/cluster": + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "OK") + default: + http.NotFound(w, r) + } + })) for idx := range mockPods { mp := mockPods[idx] + server.attachToPod(t, mp) metav1.SetMetaDataLabel(&mp.ObjectMeta, api.SeedNodeLabel, "true") trackObjects = append(trackObjects, mp) } @@ -3687,36 +3661,18 @@ func TestCheckPodsReadyAllStarted(t *testing.T) { for i := 0; i < int(*desiredStatefulSet.Spec.Replicas); i++ { ep := httphelper.EndpointState{ - RpcAddress: fmt.Sprintf("192.168.1.%d", i+1), + RpcAddress: mockPods[i].Status.PodIP, Status: "UN", } epData.Entity = append(epData.Entity, ep) } - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("OK")), - } - - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req != nil - })). - Return(res, nil). - Times(len(epData.Entity) * 2) // reloadSeeds * pods + clusterHealthCheck * pods - - client := httphelper.NodeMgmtClient{ - Client: mockHttpClient, - Log: rc.ReqLogger, - Protocol: "http", - } - - rc.NodeMgmtClient = client + rc.NodeMgmtClient = server.client(rc.ReqLogger) recRes := rc.CheckPodsReady(epData) assert.Equal(result.Continue(), recRes) // All pods should be up, no need to call anything + server.assertCallCount(t, "/api/v0/ops/seeds/reload", len(epData.Entity)) + server.assertCallCount(t, "/api/v0/probes/cluster", len(epData.Entity)) } func TestShouldUseFastPath(t *testing.T) { @@ -3748,6 +3704,13 @@ func TestUpdateCassandraNodeStatus_HostIDExtraction(t *testing.T) { api.CassNodeState: stateStarted, // Need this to pass the isMgmtApiRunning check }, }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "cassandra", + }, + }, + }, Status: corev1.PodStatus{ PodIP: "10.244.0.1", ContainerStatuses: []corev1.ContainerStatus{ @@ -3784,8 +3747,7 @@ func TestUpdateCassandraNodeStatus_HostIDExtraction(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mockHttpClient := mocks.NewHttpClient(t) - + pod := testPod.DeepCopy() featuresJson := []byte(`{ "cassandra_version": "4.0.0", "features": ["async_flush_task"] @@ -3797,8 +3759,26 @@ func TestUpdateCassandraNodeStatus_HostIDExtraction(t *testing.T) { }`) } - // Current mgmt-api - modernEndpointsJson := []byte(`{ + rc, _, cleanupMockScr := setupTest() + defer cleanupMockScr() + + rc.Datacenter.Status.NodeStatuses = map[string]api.CassandraNodeStatus{} + var endpointsJson []byte + + server := newFakeMgmtApiServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RequestURI() { + case "/api/v0/metadata/versions/features": + w.WriteHeader(http.StatusOK) + _, _ = w.Write(featuresJson) + case "/api/v0/metadata/endpoints": + w.WriteHeader(http.StatusOK) + _, _ = w.Write(endpointsJson) + default: + http.NotFound(w, r) + } + })) + server.attachToPod(t, pod) + 
endpointsJson = []byte(`{ "entity": [ { "ENDPOINT_IP": "255.244.0.1", @@ -3810,48 +3790,29 @@ func TestUpdateCassandraNodeStatus_HostIDExtraction(t *testing.T) { } ] }`) - - // Date: Mon, 13 Apr 2026 20:04:38 +0300 Subject: [PATCH 3/5] Add missing httpClient removals from client_test.go (pretty straightforward modification from mock) --- pkg/httphelper/client_test.go | 128 +++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 56 deletions(-) diff --git a/pkg/httphelper/client_test.go b/pkg/httphelper/client_test.go index e59304ddc..86adc3975 100644 --- a/pkg/httphelper/client_test.go +++ b/pkg/httphelper/client_test.go @@ -16,8 +16,6 @@ import ( "testing" "github.com/go-logr/logr" - "github.com/k8ssandra/cass-operator/pkg/mocks" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" @@ -226,7 +224,7 @@ func TestNodeMgmtClient_GetKeyspaceReplication(t *testing.T) { name string pod *corev1.Pod keyspaceName string - httpClient *mocks.HttpClient + httpClient HttpClient expected map[string]string err error }{ @@ -289,7 +287,7 @@ func TestNodeMgmtClient_ListTables(t *testing.T) { name string pod *corev1.Pod keyspaceName string - httpClient *mocks.HttpClient + httpClient HttpClient expected []string err error }{ @@ -359,7 +357,7 @@ func TestNodeMgmtClient_CreateTable(t *testing.T) { name string pod *corev1.Pod table *TableDefinition - httpClient *mocks.HttpClient + httpClient HttpClient err error }{ { @@ -461,16 +459,14 @@ func TestListRoles(t *testing.T) { require.NoError(err) require.Equal(3, len(roles)) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/auth/role" && req.Method == http.MethodGet - })). - Return(newHttpResponse(payload, http.StatusOK), nil). - Once() + httpClient := newAssertingHttpClient(t, func(req *http.Request) { + require.Equal(http.MethodGet, req.Method) + require.Equal("/api/v0/ops/auth/role", req.URL.Path) + }, func() *http.Response { + return newHttpResponse(payload, http.StatusOK) + }, nil) - mgmtClient := newMockMgmtClient(mockHttpClient) + mgmtClient := newMockMgmtClient(httpClient) roles, err = mgmtClient.CallListRolesEndpoint(goodPod) require.NoError(err) require.Equal(3, len(roles)) @@ -478,53 +474,51 @@ func TestListRoles(t *testing.T) { func TestCreateRole(t *testing.T) { require := require.New(t) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/auth/role" && req.Method == http.MethodPost && req.URL.Query().Get("username") == "role1" && req.URL.Query().Get("password") == "password1" && req.URL.Query().Get("is_superuser") == "true" - })). - Return(newHttpResponseMarshalled("OK", http.StatusOK), nil). 
- Once() - - mgmtClient := newMockMgmtClient(mockHttpClient) + httpClient := newAssertingHttpClient(t, func(req *http.Request) { + require.Equal(http.MethodPost, req.Method) + require.Equal("/api/v0/ops/auth/role", req.URL.Path) + require.Equal("role1", req.URL.Query().Get("username")) + require.Equal("password1", req.URL.Query().Get("password")) + require.Equal("true", req.URL.Query().Get("is_superuser")) + }, func() *http.Response { + return newHttpResponseMarshalled("OK", http.StatusOK) + }, nil) + + mgmtClient := newMockMgmtClient(httpClient) err := mgmtClient.CallCreateRoleEndpoint(goodPod, "role1", "password1", true) require.NoError(err) - require.True(mockHttpClient.AssertExpectations(t)) } func TestDropRole(t *testing.T) { require := require.New(t) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", - mock.MatchedBy( - func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/auth/role" && req.Method == http.MethodDelete - })). - Return(newHttpResponseMarshalled("OK", http.StatusOK), nil). - Once() - - mgmtClient := newMockMgmtClient(mockHttpClient) + httpClient := newAssertingHttpClient(t, func(req *http.Request) { + require.Equal(http.MethodDelete, req.Method) + require.Equal("/api/v0/ops/auth/role", req.URL.Path) + }, func() *http.Response { + return newHttpResponseMarshalled("OK", http.StatusOK) + }, nil) + + mgmtClient := newMockMgmtClient(httpClient) err := mgmtClient.CallDropRoleEndpoint(goodPod, "role1") require.NoError(err) - require.True(mockHttpClient.AssertExpectations(t)) } func TestCallDurationMetricSuccess(t *testing.T) { require := require.New(t) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/node/drain" && req.Method == http.MethodPost - })).Return(newHttpResponseMarshalled("OK", http.StatusOK), nil).Times(2) + httpClient := newAssertingHttpClient(t, func(req *http.Request) { + require.Equal(http.MethodPost, req.Method) + require.Equal("/api/v0/ops/node/drain", req.URL.Path) + }, func() *http.Response { + return newHttpResponseMarshalled("OK", http.StatusOK) + }, nil) before := getHttpHelperCallDurationCount(t, "/api/v0/ops/node/drain", resultSuccessLabelName) - mgmtClient := newMockMgmtClient(mockHttpClient) + mgmtClient := newMockMgmtClient(httpClient) require.NoError(mgmtClient.CallDrainEndpoint(goodPod)) require.NoError(mgmtClient.CallDrainEndpoint(goodPod)) - require.True(mockHttpClient.AssertExpectations(t)) after := getHttpHelperCallDurationCount(t, "/api/v0/ops/node/drain", resultSuccessLabelName) require.Equal(before+2, after) @@ -533,22 +527,29 @@ func TestCallDurationMetricSuccess(t *testing.T) { func TestCallDurationMetricError(t *testing.T) { require := require.New(t) - mockHttpClient := mocks.NewHttpClient(t) - mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { - return req.URL.Path == "/api/v0/ops/seeds/reload" && req.Method == http.MethodPost - })).Return(newHttpResponseMarshalled("this is an error", http.StatusInternalServerError), nil).Once() + httpClient := newAssertingHttpClient(t, func(req *http.Request) { + require.Equal(http.MethodPost, req.Method) + require.Equal("/api/v0/ops/seeds/reload", req.URL.Path) + }, func() *http.Response { + return newHttpResponseMarshalled("this is an error", http.StatusInternalServerError) + }, nil) before := getHttpHelperCallDurationCount(t, "/api/v0/ops/seeds/reload", resultErrorLabelName) - mgmtClient := newMockMgmtClient(mockHttpClient) + mgmtClient := 
newMockMgmtClient(httpClient) require.Error(mgmtClient.CallReloadSeedsEndpoint(goodPod)) - require.True(mockHttpClient.AssertExpectations(t)) after := getHttpHelperCallDurationCount(t, "/api/v0/ops/seeds/reload", resultErrorLabelName) require.Equal(before+1, after) } -func newMockMgmtClient(httpClient *mocks.HttpClient) *NodeMgmtClient { +type httpClientDoFunc func(*http.Request) (*http.Response, error) + +func (f httpClientDoFunc) Do(req *http.Request) (*http.Response, error) { + return f(req) +} + +func newMockMgmtClient(httpClient HttpClient) *NodeMgmtClient { return &NodeMgmtClient{ Client: httpClient, Log: logr.Discard(), @@ -556,10 +557,27 @@ func newMockMgmtClient(httpClient *mocks.HttpClient) *NodeMgmtClient { } } -func newMockHttpClient(response *http.Response, err error) *mocks.HttpClient { - httpClient := new(mocks.HttpClient) - httpClient.On("Do", mock.Anything).Return(response, err) - return httpClient +func newMockHttpClient(response *http.Response, err error) HttpClient { + return httpClientDoFunc(func(*http.Request) (*http.Response, error) { + return response, err + }) +} + +func newAssertingHttpClient( + t *testing.T, + assertRequest func(*http.Request), + response func() *http.Response, + err error, +) HttpClient { + t.Helper() + + return httpClientDoFunc(func(req *http.Request) (*http.Response, error) { + assertRequest(req) + if response == nil { + return nil, err + } + return response(), err + }) } func getHttpHelperCallDurationCount(t *testing.T, route, result string) uint64 { @@ -652,9 +670,7 @@ func TestCustomTransport(t *testing.T) { }, } - mockClient := mocks.NewClient(t) - - mgmtClient, err := NewMgmtClient(t.Context(), mockClient, dc, customTransport) + mgmtClient, err := NewMgmtClient(t.Context(), nil, dc, customTransport) mgmtClient.Log = logr.Discard() require.NoError(err) From 241599ddfca0ed2957fb43f369cdd429fa115e6f Mon Sep 17 00:00:00 2001 From: Michael Burman Date: Mon, 13 Apr 2026 20:05:11 +0300 Subject: [PATCH 4/5] Remove pkg/mocks entirely and docs associated with it --- docs/developer/mocks.md | 68 --- pkg/mocks/Client.go | 949 --------------------------------- pkg/mocks/HttpClient.go | 100 ---- pkg/mocks/SubResourceClient.go | 115 ---- pkg/mocks/helper.go | 40 -- 5 files changed, 1272 deletions(-) delete mode 100644 docs/developer/mocks.md delete mode 100644 pkg/mocks/Client.go delete mode 100644 pkg/mocks/HttpClient.go delete mode 100644 pkg/mocks/SubResourceClient.go delete mode 100644 pkg/mocks/helper.go diff --git a/docs/developer/mocks.md b/docs/developer/mocks.md deleted file mode 100644 index c3973b68c..000000000 --- a/docs/developer/mocks.md +++ /dev/null @@ -1,68 +0,0 @@ -To regenerate the controller-runtime client mocks we rely on `mockery` (v3.x at the moment). Because we wrap some of the generated types we cannot run `mockery` directly in `pkg/mocks`, so follow the steps below any time controller-runtime introduces a new method or signature. - -Install mockery: https://vektra.github.io/mockery/latest/installation/ - -## Create - -The helper interfaces live in `pkg/mockhelper/http.go`. 
If you ever need to recreate the file the contents should look like this: - -```go -package mockhelper - -import ( - "net/http" - - client "sigs.k8s.io/controller-runtime/pkg/client" -) - -type Client interface { - client.Client -} - -type HttpClient interface { - Do(req *http.Request) (*http.Response, error) -} -``` - -## Generate mocks - -Run mockery once per interface, writing the results to the helper package (we copy the files into `pkg/mocks` afterwards). From the repo root: - -```bash -mockery \ - --dir=pkg/mockhelper \ - --name=Client \ - --filename=Client.go \ - --structname=Client \ - --outpkg=mocks \ - --output=pkg/mockhelper \ - --with-expecter \ - --disable-version-string - -mockery \ - --dir=pkg/mockhelper \ - --name=HttpClient \ - --filename=HttpClient.go \ - --structname=HttpClient \ - --outpkg=mocks \ - --output=pkg/mockhelper \ - --with-expecter \ - --disable-version-string -``` - -Copy the resulting files from `pkg/mockhelper/*.go` to `pkg/mocks/`. - -## Customize the generated client - -Delete the temporary files from `pkg/mockhelper/` once the copies have been taken, and then adjust the generated controllers-runtime client mock in `pkg/mocks/Client.go`. - -Remove from controller-runtime client (`Client.go`): - -``` -// Client is an autogenerated mock type for the Client type -type Client struct { - mock.Mock -} -```` - -Comment out methods: `NewClient()`, `SubResource(subResource string)` and `Status()`. All of these are already separately created in `mocks/helper.go`, since we want to modify the behavior of certain Status checks in our test code. diff --git a/pkg/mocks/Client.go b/pkg/mocks/Client.go deleted file mode 100644 index 3010d5a27..000000000 --- a/pkg/mocks/Client.go +++ /dev/null @@ -1,949 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mocks - -import ( - "context" - - mock "github.com/stretchr/testify/mock" - "k8s.io/apimachinery/pkg/api/meta" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -/* -func NewClient(t interface { - mock.TestingT - Cleanup(func()) -}) *Client { - mock := &Client{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} -*/ - -// Client is an autogenerated mock type for the Client type -/* -type Client struct { - mock.Mock -} -*/ - -type Client_Expecter struct { - mock *mock.Mock -} - -func (_m *Client) EXPECT() *Client_Expecter { - return &Client_Expecter{mock: &_m.Mock} -} - -// Apply provides a mock function for the type Client -func (_mock *Client) Apply(ctx context.Context, obj runtime.ApplyConfiguration, opts ...client.ApplyOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, opts) - } else { - tmpRet = _mock.Called(ctx, obj) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Apply") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, runtime.ApplyConfiguration, ...client.ApplyOption) error); ok { - r0 = returnFunc(ctx, obj, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Apply_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Apply' -type Client_Apply_Call struct { - *mock.Call -} - -// Apply is a helper method to define mock.On call -// - ctx context.Context -// - obj runtime.ApplyConfiguration -// - opts ...client.ApplyOption -func (_e *Client_Expecter) Apply(ctx interface{}, obj interface{}, opts ...interface{}) *Client_Apply_Call { - return &Client_Apply_Call{Call: _e.mock.On("Apply", - append([]interface{}{ctx, obj}, opts...)...)} -} - -func (_c *Client_Apply_Call) Run(run func(ctx context.Context, obj runtime.ApplyConfiguration, opts ...client.ApplyOption)) *Client_Apply_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 runtime.ApplyConfiguration - if args[1] != nil { - arg1 = args[1].(runtime.ApplyConfiguration) - } - var arg2 []client.ApplyOption - var variadicArgs []client.ApplyOption - if len(args) > 2 { - variadicArgs = args[2].([]client.ApplyOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_Apply_Call) Return(err error) *Client_Apply_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Apply_Call) RunAndReturn(run func(ctx context.Context, obj runtime.ApplyConfiguration, opts ...client.ApplyOption) error) *Client_Apply_Call { - _c.Call.Return(run) - return _c -} - -// Create provides a mock function for the type Client -func (_mock *Client) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, opts) - } else { - tmpRet = _mock.Called(ctx, obj) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.Object, ...client.CreateOption) error); ok { - r0 = returnFunc(ctx, obj, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type Client_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - obj client.Object -// - opts ...client.CreateOption -func (_e *Client_Expecter) Create(ctx interface{}, obj interface{}, opts ...interface{}) *Client_Create_Call { - return &Client_Create_Call{Call: _e.mock.On("Create", - append([]interface{}{ctx, obj}, opts...)...)} -} - -func (_c *Client_Create_Call) Run(run func(ctx context.Context, obj client.Object, opts ...client.CreateOption)) *Client_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.Object - if args[1] != nil { - arg1 = args[1].(client.Object) - } - var arg2 []client.CreateOption - var variadicArgs []client.CreateOption - if len(args) > 2 { - variadicArgs = args[2].([]client.CreateOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_Create_Call) Return(err error) *Client_Create_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Create_Call) RunAndReturn(run func(ctx context.Context, obj client.Object, opts ...client.CreateOption) error) *Client_Create_Call { - _c.Call.Return(run) - return _c -} - -// Delete provides a mock function for the type Client -func (_mock *Client) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, opts) - } else { - tmpRet = _mock.Called(ctx, obj) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Delete") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.Object, ...client.DeleteOption) error); ok { - r0 = returnFunc(ctx, obj, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' -type Client_Delete_Call struct { - *mock.Call -} - -// Delete is a helper method to define mock.On call -// - ctx context.Context -// - obj client.Object -// - opts ...client.DeleteOption -func (_e *Client_Expecter) Delete(ctx interface{}, obj interface{}, opts ...interface{}) *Client_Delete_Call { - return &Client_Delete_Call{Call: _e.mock.On("Delete", - append([]interface{}{ctx, obj}, opts...)...)} -} - -func (_c *Client_Delete_Call) Run(run func(ctx context.Context, obj client.Object, opts ...client.DeleteOption)) *Client_Delete_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.Object - if args[1] != nil { - arg1 = args[1].(client.Object) - } - var arg2 []client.DeleteOption - var variadicArgs []client.DeleteOption - if len(args) > 2 { - variadicArgs = args[2].([]client.DeleteOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_Delete_Call) Return(err error) *Client_Delete_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Delete_Call) RunAndReturn(run func(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error) *Client_Delete_Call { - _c.Call.Return(run) - return _c -} - -// DeleteAllOf provides a mock function for the type Client -func (_mock *Client) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, opts) - } else { - tmpRet = _mock.Called(ctx, obj) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for DeleteAllOf") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.Object, ...client.DeleteAllOfOption) error); ok { - r0 = returnFunc(ctx, obj, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_DeleteAllOf_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAllOf' -type Client_DeleteAllOf_Call struct { - *mock.Call -} - -// DeleteAllOf is a helper method to define mock.On call -// - ctx context.Context -// - obj client.Object -// - opts ...client.DeleteAllOfOption -func (_e *Client_Expecter) DeleteAllOf(ctx interface{}, obj interface{}, opts ...interface{}) *Client_DeleteAllOf_Call { - return &Client_DeleteAllOf_Call{Call: _e.mock.On("DeleteAllOf", - append([]interface{}{ctx, obj}, opts...)...)} -} - -func (_c *Client_DeleteAllOf_Call) Run(run func(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption)) *Client_DeleteAllOf_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.Object - if args[1] != nil { - arg1 = args[1].(client.Object) - } - var arg2 []client.DeleteAllOfOption - var variadicArgs []client.DeleteAllOfOption - if len(args) > 2 { - variadicArgs = args[2].([]client.DeleteAllOfOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_DeleteAllOf_Call) Return(err error) *Client_DeleteAllOf_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_DeleteAllOf_Call) RunAndReturn(run func(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error) *Client_DeleteAllOf_Call { - _c.Call.Return(run) - return _c -} - -// Get provides a mock function for the type Client -func (_mock *Client) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, key, obj, opts) - } else { - tmpRet = _mock.Called(ctx, key, obj) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Get") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.ObjectKey, client.Object, ...client.GetOption) error); ok { - r0 = returnFunc(ctx, key, obj, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' -type Client_Get_Call struct { - *mock.Call -} - -// Get is a helper method to define mock.On call -// - ctx context.Context -// - key client.ObjectKey -// - obj client.Object -// - opts ...client.GetOption -func (_e *Client_Expecter) Get(ctx interface{}, key interface{}, obj interface{}, opts ...interface{}) *Client_Get_Call { - return &Client_Get_Call{Call: _e.mock.On("Get", - append([]interface{}{ctx, key, obj}, opts...)...)} -} - -func (_c *Client_Get_Call) Run(run func(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption)) *Client_Get_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.ObjectKey - if args[1] != nil { - arg1 = args[1].(client.ObjectKey) - } - var arg2 client.Object - if args[2] != nil { - arg2 = args[2].(client.Object) - } - var arg3 []client.GetOption - var variadicArgs []client.GetOption - if len(args) > 3 { - variadicArgs = args[3].([]client.GetOption) - } - arg3 = variadicArgs - run( - arg0, - arg1, - arg2, - arg3..., - ) - }) - return _c -} - -func (_c *Client_Get_Call) Return(err error) *Client_Get_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Get_Call) RunAndReturn(run func(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error) *Client_Get_Call { - _c.Call.Return(run) - return _c -} - -// GroupVersionKindFor provides a mock function for the type Client -func (_mock *Client) GroupVersionKindFor(obj runtime.Object) (schema.GroupVersionKind, error) { - ret := _mock.Called(obj) - - if len(ret) == 0 { - panic("no return value specified for GroupVersionKindFor") - } - - var r0 schema.GroupVersionKind - var r1 error - if returnFunc, ok := ret.Get(0).(func(runtime.Object) (schema.GroupVersionKind, error)); ok { - return returnFunc(obj) - } - if returnFunc, ok := ret.Get(0).(func(runtime.Object) schema.GroupVersionKind); ok { - r0 = returnFunc(obj) - } else { - r0 = ret.Get(0).(schema.GroupVersionKind) - } - if returnFunc, ok := ret.Get(1).(func(runtime.Object) error); ok { - r1 = returnFunc(obj) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// Client_GroupVersionKindFor_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GroupVersionKindFor' -type Client_GroupVersionKindFor_Call struct { - *mock.Call -} - -// GroupVersionKindFor is a helper method to define mock.On call -// - obj runtime.Object -func (_e *Client_Expecter) GroupVersionKindFor(obj interface{}) *Client_GroupVersionKindFor_Call { - return &Client_GroupVersionKindFor_Call{Call: _e.mock.On("GroupVersionKindFor", obj)} -} - -func (_c *Client_GroupVersionKindFor_Call) Run(run func(obj runtime.Object)) *Client_GroupVersionKindFor_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 runtime.Object - if args[0] != nil { - arg0 = args[0].(runtime.Object) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *Client_GroupVersionKindFor_Call) Return(groupVersionKind schema.GroupVersionKind, err error) *Client_GroupVersionKindFor_Call { - _c.Call.Return(groupVersionKind, err) - return _c -} - -func (_c *Client_GroupVersionKindFor_Call) RunAndReturn(run func(obj runtime.Object) (schema.GroupVersionKind, error)) *Client_GroupVersionKindFor_Call { - _c.Call.Return(run) - return _c -} - -// IsObjectNamespaced 
provides a mock function for the type Client -func (_mock *Client) IsObjectNamespaced(obj runtime.Object) (bool, error) { - ret := _mock.Called(obj) - - if len(ret) == 0 { - panic("no return value specified for IsObjectNamespaced") - } - - var r0 bool - var r1 error - if returnFunc, ok := ret.Get(0).(func(runtime.Object) (bool, error)); ok { - return returnFunc(obj) - } - if returnFunc, ok := ret.Get(0).(func(runtime.Object) bool); ok { - r0 = returnFunc(obj) - } else { - r0 = ret.Get(0).(bool) - } - if returnFunc, ok := ret.Get(1).(func(runtime.Object) error); ok { - r1 = returnFunc(obj) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// Client_IsObjectNamespaced_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsObjectNamespaced' -type Client_IsObjectNamespaced_Call struct { - *mock.Call -} - -// IsObjectNamespaced is a helper method to define mock.On call -// - obj runtime.Object -func (_e *Client_Expecter) IsObjectNamespaced(obj interface{}) *Client_IsObjectNamespaced_Call { - return &Client_IsObjectNamespaced_Call{Call: _e.mock.On("IsObjectNamespaced", obj)} -} - -func (_c *Client_IsObjectNamespaced_Call) Run(run func(obj runtime.Object)) *Client_IsObjectNamespaced_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 runtime.Object - if args[0] != nil { - arg0 = args[0].(runtime.Object) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *Client_IsObjectNamespaced_Call) Return(b bool, err error) *Client_IsObjectNamespaced_Call { - _c.Call.Return(b, err) - return _c -} - -func (_c *Client_IsObjectNamespaced_Call) RunAndReturn(run func(obj runtime.Object) (bool, error)) *Client_IsObjectNamespaced_Call { - _c.Call.Return(run) - return _c -} - -// List provides a mock function for the type Client -func (_mock *Client) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, list, opts) - } else { - tmpRet = _mock.Called(ctx, list) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for List") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.ObjectList, ...client.ListOption) error); ok { - r0 = returnFunc(ctx, list, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' -type Client_List_Call struct { - *mock.Call -} - -// List is a helper method to define mock.On call -// - ctx context.Context -// - list client.ObjectList -// - opts ...client.ListOption -func (_e *Client_Expecter) List(ctx interface{}, list interface{}, opts ...interface{}) *Client_List_Call { - return &Client_List_Call{Call: _e.mock.On("List", - append([]interface{}{ctx, list}, opts...)...)} -} - -func (_c *Client_List_Call) Run(run func(ctx context.Context, list client.ObjectList, opts ...client.ListOption)) *Client_List_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.ObjectList - if args[1] != nil { - arg1 = args[1].(client.ObjectList) - } - var arg2 []client.ListOption - var variadicArgs []client.ListOption - if len(args) > 2 { - variadicArgs = args[2].([]client.ListOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_List_Call) Return(err error) *Client_List_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_List_Call) RunAndReturn(run func(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error) *Client_List_Call { - _c.Call.Return(run) - return _c -} - -// Patch provides a mock function for the type Client -func (_mock *Client) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, patch, opts) - } else { - tmpRet = _mock.Called(ctx, obj, patch) - } - ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Patch") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.Object, client.Patch, ...client.PatchOption) error); ok { - r0 = returnFunc(ctx, obj, patch, opts...) 
- } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Patch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Patch' -type Client_Patch_Call struct { - *mock.Call -} - -// Patch is a helper method to define mock.On call -// - ctx context.Context -// - obj client.Object -// - patch client.Patch -// - opts ...client.PatchOption -func (_e *Client_Expecter) Patch(ctx interface{}, obj interface{}, patch interface{}, opts ...interface{}) *Client_Patch_Call { - return &Client_Patch_Call{Call: _e.mock.On("Patch", - append([]interface{}{ctx, obj, patch}, opts...)...)} -} - -func (_c *Client_Patch_Call) Run(run func(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption)) *Client_Patch_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.Object - if args[1] != nil { - arg1 = args[1].(client.Object) - } - var arg2 client.Patch - if args[2] != nil { - arg2 = args[2].(client.Patch) - } - var arg3 []client.PatchOption - var variadicArgs []client.PatchOption - if len(args) > 3 { - variadicArgs = args[3].([]client.PatchOption) - } - arg3 = variadicArgs - run( - arg0, - arg1, - arg2, - arg3..., - ) - }) - return _c -} - -func (_c *Client_Patch_Call) Return(err error) *Client_Patch_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Patch_Call) RunAndReturn(run func(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error) *Client_Patch_Call { - _c.Call.Return(run) - return _c -} - -// RESTMapper provides a mock function for the type Client -func (_mock *Client) RESTMapper() meta.RESTMapper { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for RESTMapper") - } - - var r0 meta.RESTMapper - if returnFunc, ok := ret.Get(0).(func() meta.RESTMapper); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(meta.RESTMapper) - } - } - return r0 -} - -// Client_RESTMapper_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RESTMapper' -type Client_RESTMapper_Call struct { - *mock.Call -} - -// RESTMapper is a helper method to define mock.On call -func (_e *Client_Expecter) RESTMapper() *Client_RESTMapper_Call { - return &Client_RESTMapper_Call{Call: _e.mock.On("RESTMapper")} -} - -func (_c *Client_RESTMapper_Call) Run(run func()) *Client_RESTMapper_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Client_RESTMapper_Call) Return(rESTMapper meta.RESTMapper) *Client_RESTMapper_Call { - _c.Call.Return(rESTMapper) - return _c -} - -func (_c *Client_RESTMapper_Call) RunAndReturn(run func() meta.RESTMapper) *Client_RESTMapper_Call { - _c.Call.Return(run) - return _c -} - -// Scheme provides a mock function for the type Client -func (_mock *Client) Scheme() *runtime.Scheme { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for Scheme") - } - - var r0 *runtime.Scheme - if returnFunc, ok := ret.Get(0).(func() *runtime.Scheme); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*runtime.Scheme) - } - } - return r0 -} - -// Client_Scheme_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Scheme' -type Client_Scheme_Call struct { - *mock.Call -} - -// Scheme is a helper method to define mock.On call -func (_e *Client_Expecter) Scheme() *Client_Scheme_Call { - 
return &Client_Scheme_Call{Call: _e.mock.On("Scheme")} -} - -func (_c *Client_Scheme_Call) Run(run func()) *Client_Scheme_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Client_Scheme_Call) Return(scheme *runtime.Scheme) *Client_Scheme_Call { - _c.Call.Return(scheme) - return _c -} - -func (_c *Client_Scheme_Call) RunAndReturn(run func() *runtime.Scheme) *Client_Scheme_Call { - _c.Call.Return(run) - return _c -} - -// Status provides a mock function for the type Client -/* -func (_mock *Client) Status() client.SubResourceWriter { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for Status") - } - - var r0 client.SubResourceWriter - if returnFunc, ok := ret.Get(0).(func() client.SubResourceWriter); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(client.SubResourceWriter) - } - } - return r0 -} - -// Client_Status_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Status' -type Client_Status_Call struct { - *mock.Call -} - -// Status is a helper method to define mock.On call -func (_e *Client_Expecter) Status() *Client_Status_Call { - return &Client_Status_Call{Call: _e.mock.On("Status")} -} - -func (_c *Client_Status_Call) Run(run func()) *Client_Status_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Client_Status_Call) Return(subResourceWriter client.SubResourceWriter) *Client_Status_Call { - _c.Call.Return(subResourceWriter) - return _c -} - -func (_c *Client_Status_Call) RunAndReturn(run func() client.SubResourceWriter) *Client_Status_Call { - _c.Call.Return(run) - return _c -} -*/ - -// SubResource provides a mock function for the type Client -/* -func (_mock *Client) SubResource(subResource string) client.SubResourceClient { - ret := _mock.Called(subResource) - - if len(ret) == 0 { - panic("no return value specified for SubResource") - } - - var r0 client.SubResourceClient - if returnFunc, ok := ret.Get(0).(func(string) client.SubResourceClient); ok { - r0 = returnFunc(subResource) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(client.SubResourceClient) - } - } - return r0 -} - -// Client_SubResource_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubResource' -type Client_SubResource_Call struct { - *mock.Call -} - -// SubResource is a helper method to define mock.On call -// - subResource string -func (_e *Client_Expecter) SubResource(subResource interface{}) *Client_SubResource_Call { - return &Client_SubResource_Call{Call: _e.mock.On("SubResource", subResource)} -} - -func (_c *Client_SubResource_Call) Run(run func(subResource string)) *Client_SubResource_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 string - if args[0] != nil { - arg0 = args[0].(string) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *Client_SubResource_Call) Return(subResourceClient client.SubResourceClient) *Client_SubResource_Call { - _c.Call.Return(subResourceClient) - return _c -} - -func (_c *Client_SubResource_Call) RunAndReturn(run func(subResource string) client.SubResourceClient) *Client_SubResource_Call { - _c.Call.Return(run) - return _c -} -*/ - -// Update provides a mock function for the type Client -func (_mock *Client) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { - var tmpRet mock.Arguments - if len(opts) > 0 { - tmpRet = _mock.Called(ctx, obj, opts) - } else { - tmpRet = _mock.Called(ctx, obj) - } - 
ret := tmpRet - - if len(ret) == 0 { - panic("no return value specified for Update") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, client.Object, ...client.UpdateOption) error); ok { - r0 = returnFunc(ctx, obj, opts...) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// Client_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' -type Client_Update_Call struct { - *mock.Call -} - -// Update is a helper method to define mock.On call -// - ctx context.Context -// - obj client.Object -// - opts ...client.UpdateOption -func (_e *Client_Expecter) Update(ctx interface{}, obj interface{}, opts ...interface{}) *Client_Update_Call { - return &Client_Update_Call{Call: _e.mock.On("Update", - append([]interface{}{ctx, obj}, opts...)...)} -} - -func (_c *Client_Update_Call) Run(run func(ctx context.Context, obj client.Object, opts ...client.UpdateOption)) *Client_Update_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 client.Object - if args[1] != nil { - arg1 = args[1].(client.Object) - } - var arg2 []client.UpdateOption - var variadicArgs []client.UpdateOption - if len(args) > 2 { - variadicArgs = args[2].([]client.UpdateOption) - } - arg2 = variadicArgs - run( - arg0, - arg1, - arg2..., - ) - }) - return _c -} - -func (_c *Client_Update_Call) Return(err error) *Client_Update_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Client_Update_Call) RunAndReturn(run func(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error) *Client_Update_Call { - _c.Call.Return(run) - return _c -} diff --git a/pkg/mocks/HttpClient.go b/pkg/mocks/HttpClient.go deleted file mode 100644 index e66db00bf..000000000 --- a/pkg/mocks/HttpClient.go +++ /dev/null @@ -1,100 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mocks - -import ( - "net/http" - - mock "github.com/stretchr/testify/mock" -) - -// NewHttpClient creates a new instance of HttpClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. 
-func NewHttpClient(t interface { - mock.TestingT - Cleanup(func()) -}) *HttpClient { - mock := &HttpClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// HttpClient is an autogenerated mock type for the HttpClient type -type HttpClient struct { - mock.Mock -} - -type HttpClient_Expecter struct { - mock *mock.Mock -} - -func (_m *HttpClient) EXPECT() *HttpClient_Expecter { - return &HttpClient_Expecter{mock: &_m.Mock} -} - -// Do provides a mock function for the type HttpClient -func (_mock *HttpClient) Do(req *http.Request) (*http.Response, error) { - ret := _mock.Called(req) - - if len(ret) == 0 { - panic("no return value specified for Do") - } - - var r0 *http.Response - var r1 error - if returnFunc, ok := ret.Get(0).(func(*http.Request) (*http.Response, error)); ok { - return returnFunc(req) - } - if returnFunc, ok := ret.Get(0).(func(*http.Request) *http.Response); ok { - r0 = returnFunc(req) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*http.Response) - } - } - if returnFunc, ok := ret.Get(1).(func(*http.Request) error); ok { - r1 = returnFunc(req) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// HttpClient_Do_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Do' -type HttpClient_Do_Call struct { - *mock.Call -} - -// Do is a helper method to define mock.On call -// - req *http.Request -func (_e *HttpClient_Expecter) Do(req interface{}) *HttpClient_Do_Call { - return &HttpClient_Do_Call{Call: _e.mock.On("Do", req)} -} - -func (_c *HttpClient_Do_Call) Run(run func(req *http.Request)) *HttpClient_Do_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 *http.Request - if args[0] != nil { - arg0 = args[0].(*http.Request) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *HttpClient_Do_Call) Return(response *http.Response, err error) *HttpClient_Do_Call { - _c.Call.Return(response, err) - return _c -} - -func (_c *HttpClient_Do_Call) RunAndReturn(run func(req *http.Request) (*http.Response, error)) *HttpClient_Do_Call { - _c.Call.Return(run) - return _c -} diff --git a/pkg/mocks/SubResourceClient.go b/pkg/mocks/SubResourceClient.go deleted file mode 100644 index 7e17bedc0..000000000 --- a/pkg/mocks/SubResourceClient.go +++ /dev/null @@ -1,115 +0,0 @@ -// Code generated by mockery v2.26.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - client "sigs.k8s.io/controller-runtime/pkg/client" - - mock "github.com/stretchr/testify/mock" -) - -// SubResourceClient is an autogenerated mock type for the SubResourceClient type -type SubResourceClient struct { - mock.Mock -} - -// Create provides a mock function with given fields: ctx, obj, subResource, opts -func (_m *SubResourceClient) Create(ctx context.Context, obj client.Object, subResource client.Object, opts ...client.SubResourceCreateOption) error { - _va := make([]interface{}, len(opts)) - for _i := range opts { - _va[_i] = opts[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, obj, subResource) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, client.Object, client.Object, ...client.SubResourceCreateOption) error); ok { - r0 = rf(ctx, obj, subResource, opts...) 
- } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Get provides a mock function with given fields: ctx, obj, subResource, opts -func (_m *SubResourceClient) Get(ctx context.Context, obj client.Object, subResource client.Object, opts ...client.SubResourceGetOption) error { - _va := make([]interface{}, len(opts)) - for _i := range opts { - _va[_i] = opts[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, obj, subResource) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, client.Object, client.Object, ...client.SubResourceGetOption) error); ok { - r0 = rf(ctx, obj, subResource, opts...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Patch provides a mock function with given fields: ctx, obj, patch, opts -func (_m *SubResourceClient) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.SubResourcePatchOption) error { - _va := make([]interface{}, len(opts)) - for _i := range opts { - _va[_i] = opts[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, obj, patch) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, client.Object, client.Patch, ...client.SubResourcePatchOption) error); ok { - r0 = rf(ctx, obj, patch, opts...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Update provides a mock function with given fields: ctx, obj, opts -func (_m *SubResourceClient) Update(ctx context.Context, obj client.Object, opts ...client.SubResourceUpdateOption) error { - _va := make([]interface{}, len(opts)) - for _i := range opts { - _va[_i] = opts[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, obj) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, client.Object, ...client.SubResourceUpdateOption) error); ok { - r0 = rf(ctx, obj, opts...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -type mockConstructorTestingTNewSubResourceClient interface { - mock.TestingT - Cleanup(func()) -} - -// NewSubResourceClient creates a new instance of SubResourceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewSubResourceClient(t mockConstructorTestingTNewSubResourceClient) *SubResourceClient { - mock := &SubResourceClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/mocks/helper.go b/pkg/mocks/helper.go deleted file mode 100644 index dd056bbdd..000000000 --- a/pkg/mocks/helper.go +++ /dev/null @@ -1,40 +0,0 @@ -package mocks - -import ( - mock "github.com/stretchr/testify/mock" - client "sigs.k8s.io/controller-runtime/pkg/client" -) - -// This file overwrites some mocked methods - -// Client is an autogenerated mock type for the Client type -type Client struct { - mock.Mock - subResourceClient *SubResourceClient -} - -// Status provides a mock function with given fields. Overrides normal mock functionality (can't be asserted) -func (_m *Client) Status() client.SubResourceWriter { - return _m.SubResource("status") -} - -// SubResource provides a mock function with given fields: subResource. Overrides normal mock functionality (can't be asserted) -func (_m *Client) SubResource(subResource string) client.SubResourceClient { - return _m.subResourceClient -} - -type mockConstructorTestingTNewClient interface { - mock.TestingT - Cleanup(func()) -} - -// NewClient creates a new instance of Client. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewClient(t mockConstructorTestingTNewClient) *Client { - subclient := NewSubResourceClient(t) - mock := &Client{subResourceClient: subclient} - mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} From 882eb44daa428042e48ba6282986918d1d8424f7 Mon Sep 17 00:00:00 2001 From: Michael Burman Date: Mon, 13 Apr 2026 20:14:50 +0300 Subject: [PATCH 5/5] Refactor setupScheme to always create a new runtimeScheme --- pkg/httphelper/client_test.go | 22 +++---- pkg/reconciliation/handler_test.go | 10 ++-- .../reconcile_datacenter_test.go | 10 ++-- pkg/reconciliation/reconcile_fql_test.go | 4 +- pkg/reconciliation/reconcile_racks_test.go | 58 +++++++++---------- pkg/reconciliation/reconcile_services_test.go | 4 +- pkg/reconciliation/testing.go | 8 +-- 7 files changed, 53 insertions(+), 63 deletions(-) diff --git a/pkg/httphelper/client_test.go b/pkg/httphelper/client_test.go index 86adc3975..38a8c2077 100644 --- a/pkg/httphelper/client_test.go +++ b/pkg/httphelper/client_test.go @@ -464,7 +464,7 @@ func TestListRoles(t *testing.T) { require.Equal("/api/v0/ops/auth/role", req.URL.Path) }, func() *http.Response { return newHttpResponse(payload, http.StatusOK) - }, nil) + }) mgmtClient := newMockMgmtClient(httpClient) roles, err = mgmtClient.CallListRolesEndpoint(goodPod) @@ -482,7 +482,7 @@ func TestCreateRole(t *testing.T) { require.Equal("true", req.URL.Query().Get("is_superuser")) }, func() *http.Response { return newHttpResponseMarshalled("OK", http.StatusOK) - }, nil) + }) mgmtClient := newMockMgmtClient(httpClient) err := mgmtClient.CallCreateRoleEndpoint(goodPod, "role1", "password1", true) @@ -496,7 +496,7 @@ func TestDropRole(t *testing.T) { require.Equal("/api/v0/ops/auth/role", req.URL.Path) }, func() *http.Response { return newHttpResponseMarshalled("OK", http.StatusOK) - }, nil) + }) mgmtClient := newMockMgmtClient(httpClient) err := mgmtClient.CallDropRoleEndpoint(goodPod, "role1") @@ -512,7 +512,7 @@ func TestCallDurationMetricSuccess(t *testing.T) { require.Equal("/api/v0/ops/node/drain", req.URL.Path) }, func() *http.Response { return newHttpResponseMarshalled("OK", http.StatusOK) - }, nil) + }) before := getHttpHelperCallDurationCount(t, "/api/v0/ops/node/drain", resultSuccessLabelName) @@ -532,7 +532,7 @@ func TestCallDurationMetricError(t *testing.T) { require.Equal("/api/v0/ops/seeds/reload", req.URL.Path) }, func() *http.Response { return newHttpResponseMarshalled("this is an error", http.StatusInternalServerError) - }, nil) + }) before := getHttpHelperCallDurationCount(t, "/api/v0/ops/seeds/reload", resultErrorLabelName) @@ -563,20 +563,12 @@ func newMockHttpClient(response *http.Response, err error) HttpClient { }) } -func newAssertingHttpClient( - t *testing.T, - assertRequest func(*http.Request), - response func() *http.Response, - err error, -) HttpClient { +func newAssertingHttpClient(t *testing.T, assertRequest func(*http.Request), response func() *http.Response) HttpClient { t.Helper() return httpClientDoFunc(func(req *http.Request) (*http.Response, error) { assertRequest(req) - if response == nil { - return nil, err - } - return response(), err + return response(), nil }) } diff --git a/pkg/reconciliation/handler_test.go b/pkg/reconciliation/handler_test.go index 35123a86b..9b0b9b37a 100644 --- a/pkg/reconciliation/handler_test.go +++ b/pkg/reconciliation/handler_test.go @@ -56,7 +56,7 @@ func TestCalculateReconciliationActions(t 
*testing.T) { pod, } - fakeClient := fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter, service).WithRuntimeObjects(trackObjects...).Build() + fakeClient := fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter, service).WithRuntimeObjects(trackObjects...).Build() rc.Client = fakeClient result, err := rc.CalculateReconciliationActions() @@ -74,7 +74,7 @@ func TestCalculateReconciliationActions_GetServiceError(t *testing.T) { getErr := fmt.Errorf("") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -98,7 +98,7 @@ func TestCalculateReconciliationActions_FailedUpdate(t *testing.T) { updateErr := fmt.Errorf("failed to update CassandraDatacenter with removed finalizers") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -146,7 +146,7 @@ func TestProcessDeletion_FailedDelete(t *testing.T) { pvc := pvcProto(rc) deleteErr := fmt.Errorf("failed to delete pvc") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, sts, pvc). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -184,7 +184,7 @@ func TestProcessDeletion(t *testing.T) { assert.NoError(err) pvc := pvcProto(rc) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, sts, pvc). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). diff --git a/pkg/reconciliation/reconcile_datacenter_test.go b/pkg/reconciliation/reconcile_datacenter_test.go index ce4311882..8bee8786d 100644 --- a/pkg/reconciliation/reconcile_datacenter_test.go +++ b/pkg/reconciliation/reconcile_datacenter_test.go @@ -30,7 +30,7 @@ func TestDeletePVCs(t *testing.T) { pvc := pvcProto(rc) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, pvc). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -48,7 +48,7 @@ func TestDeletePVCs_FailedToList(t *testing.T) { listErr := fmt.Errorf("failed to list PVCs for CassandraDatacenter") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -75,7 +75,7 @@ func TestDeletePVCs_PVCsNotFound(t *testing.T) { notFoundErr := errors.NewNotFound(schema.GroupResource{}, "name") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -99,7 +99,7 @@ func TestDeletePVCs_FailedToDelete(t *testing.T) { pvc := pvcProto(rc) deleteErr := fmt.Errorf("failed to delete") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, pvc). 
WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -129,7 +129,7 @@ func TestDeletePVCs_FailedToDeleteBeingUsed(t *testing.T) { pvc := pvcProto(rc) pod := podWithPVC(rc, "pod-1", pvc.Name) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, pvc, pod). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). diff --git a/pkg/reconciliation/reconcile_fql_test.go b/pkg/reconciliation/reconcile_fql_test.go index b0e1229ef..b6ec55bfd 100644 --- a/pkg/reconciliation/reconcile_fql_test.go +++ b/pkg/reconciliation/reconcile_fql_test.go @@ -61,7 +61,7 @@ func setupPodList(rc *ReconciliationContext) { } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter, &pods[0]). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -105,7 +105,7 @@ func attachPods(t *testing.T, rc *ReconciliationContext, server *fakeMgmtApiServ server.attachToPod(t, rc.dcPods[0]) server.attachToPod(t, rc.clusterPods[0]) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter, rc.dcPods[0]). WithRuntimeObjects(rc.Datacenter, rc.dcPods[0]). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). diff --git a/pkg/reconciliation/reconcile_racks_test.go b/pkg/reconciliation/reconcile_racks_test.go index 7e6c3e5ec..645c1d24f 100644 --- a/pkg/reconciliation/reconcile_racks_test.go +++ b/pkg/reconciliation/reconcile_racks_test.go @@ -252,7 +252,7 @@ func TestReconcileRacks_ReconcilePods(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() rc.NodeMgmtClient = server.client(rc.ReqLogger) nextRack := &RackInformation{} @@ -602,7 +602,7 @@ func TestReconcilePods_WithVolumes(t *testing.T) { pvc, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(pod, pvc).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(pod, pvc).WithRuntimeObjects(trackObjects...).Build() err = rc.ReconcilePods(statefulSet) assert.NoErrorf(t, err, "Should not have returned an error") } @@ -648,7 +648,7 @@ func TestReconcileNextRack_CreateError(t *testing.T) { assert.NoErrorf(t, err, "error occurred creating statefulset") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
@@ -742,7 +742,7 @@ func TestReconcileRacks(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(desiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -767,7 +767,7 @@ func TestReconcileRacks_GetStatefulsetError(t *testing.T) { defer cleanupMockScr() rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -823,7 +823,7 @@ func TestReconcileRacks_WaitingForReplicas(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(desiredStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -860,7 +860,7 @@ func TestReconcileRacks_NeedMoreReplicas(t *testing.T) { preExistingStatefulSet, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -902,7 +902,7 @@ func TestReconcileRacks_DoesntScaleDown(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(preExistingStatefulSet).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -939,7 +939,7 @@ func TestReconcileRacks_NeedToPark(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(preExistingStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(preExistingStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -990,7 +990,7 @@ func TestReconcileRacks_AlreadyReconciled(t *testing.T) { desiredPdb, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1067,7 +1067,7 @@ func TestReconcileRacks_FirstRackAlreadyReconciled(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, secondDesiredStatefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(desiredStatefulSet, secondDesiredStatefulSet, 
rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1137,7 +1137,7 @@ func TestReconcileRacks_UpdateRackNodeCount(t *testing.T) { rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(tt.args.statefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(tt.args.statefulSet, rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() if err := rc.UpdateRackNodeCount(tt.args.statefulSet, tt.args.newNodeCount); (err != nil) != tt.wantErr { t.Errorf("updateRackNodeCount() error = %v, wantErr %v", err, tt.wantErr) @@ -1179,7 +1179,7 @@ func TestReconcileRacks_UpdateConfig(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(desiredStatefulSet, rc.Datacenter, desiredPdb).WithRuntimeObjects(trackObjects...).Build() var rackInfo []*RackInformation @@ -1828,7 +1828,7 @@ func TestCleanupAfterScaling(t *testing.T) { assert := assert.New(t) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -1854,7 +1854,7 @@ func TestCleanupAfterScalingWithTracker(t *testing.T) { metav1.SetMetaDataAnnotation(&rc.Datacenter.ObjectMeta, api.TrackCleanupTasksAnnotation, "true") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -1886,7 +1886,7 @@ func TestCleanupAfterScalingWithParallelAnnotation(t *testing.T) { _ = rc.CalculateRackInformation() metav1.SetMetaDataAnnotation(&rc.Datacenter.ObjectMeta, api.EnableParallelCleanupWithinRackAnnotation, "true") rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -2004,7 +2004,7 @@ func TestFailedStart(t *testing.T) { pod := makeReloadTestPod() server.attachToPod(t, pod) rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, nil, []*corev1.Pod{pod})...). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -2278,7 +2278,7 @@ func TestStartBootstrappedNodes(t *testing.T) { } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(trackObjects...). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -2597,7 +2597,7 @@ func TestReconciliationContext_startAllNodes(t *testing.T) { defer server.assertCallCount(t, "/api/v0/lifecycle/start", 1) } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). 
@@ -2721,7 +2721,7 @@ func TestReconciliationContext_startAllNodes_onlyRackInformation(t *testing.T) { } } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -2912,7 +2912,7 @@ func TestStartOneNodePerRack(t *testing.T) { rc.NodeMgmtClient = server.client(rc.ReqLogger) } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.dcPods)...). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -3442,7 +3442,7 @@ func TestDatacenterPods(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3487,7 +3487,7 @@ func TestDatacenterPodsOldLabels(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3541,7 +3541,7 @@ func TestDatacenterPodsNoDualFetch(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = "default" @@ -3587,7 +3587,7 @@ func TestCheckRackLabels(t *testing.T) { desiredStatefulSet, rc.Datacenter, } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() rc.statefulSets = []*appsv1.StatefulSet{desiredStatefulSet} @@ -3639,7 +3639,7 @@ func TestCheckPodsReadyAllStarted(t *testing.T) { trackObjects = append(trackObjects, mp) } - rc.Client = fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() + rc.Client = fake.NewClientBuilder().WithScheme(setupScheme()).WithStatusSubresource(rc.Datacenter).WithRuntimeObjects(trackObjects...).Build() nextRack := &RackInformation{} nextRack.RackName = desiredStatefulSet.Labels[api.RackLabel] @@ -4152,7 +4152,7 @@ func TestRefreshSeeds(t *testing.T) { epData.Entity[i].RpcAddress = pod.Status.PodIP } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.clusterPods)...). 
Build() @@ -4181,7 +4181,7 @@ func TestRefreshSeeds(t *testing.T) { epData.Entity[i].RpcAddress = pod.Status.PodIP } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.clusterPods)...). Build() @@ -4211,7 +4211,7 @@ func TestRefreshSeeds(t *testing.T) { pod.Status.PodIP = "127.0.0.1" } rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(runtimeObjectHelper(rc, rc.statefulSets, rc.clusterPods)...). Build() diff --git a/pkg/reconciliation/reconcile_services_test.go b/pkg/reconciliation/reconcile_services_test.go index 10413538c..91d2688e7 100644 --- a/pkg/reconciliation/reconcile_services_test.go +++ b/pkg/reconciliation/reconcile_services_test.go @@ -100,7 +100,7 @@ func TestCreateHeadlessService_ClientReturnsError(t *testing.T) { defer cleanupMockScr() rc.Client = fake.NewClientBuilder(). - WithScheme(setupScheme(nil)). + WithScheme(setupScheme()). WithStatusSubresource(rc.Datacenter). WithRuntimeObjects(rc.Datacenter). WithIndex(&corev1.Pod{}, podPVCClaimNameField, podPVCClaimNames). @@ -127,7 +127,7 @@ func TestEndpointSliceControllerIntegration(t *testing.T) { rc, _, cleanupMockScr := setupTest() defer cleanupMockScr() - fakeClient := fake.NewClientBuilder().WithScheme(setupScheme(nil)).WithRuntimeObjects(rc.Datacenter).Build() + fakeClient := fake.NewClientBuilder().WithScheme(setupScheme()).WithRuntimeObjects(rc.Datacenter).Build() rc.Client = fakeClient rc.Datacenter.Spec.AdditionalSeeds = []string{ diff --git a/pkg/reconciliation/testing.go b/pkg/reconciliation/testing.go index c13bedee1..8fe7ac91b 100644 --- a/pkg/reconciliation/testing.go +++ b/pkg/reconciliation/testing.go @@ -150,7 +150,7 @@ func CreateMockReconciliationContext( storageClass, } - s := setupScheme(runtime.NewScheme()) + s := setupScheme() fakeClient := fake.NewClientBuilder(). WithScheme(s). WithStatusSubresource(cassandraDatacenter). @@ -192,10 +192,8 @@ func setupTest() (*ReconciliationContext, *corev1.Service, func()) { return rc, service, cleanupMockScr } -func setupScheme(scheme *runtime.Scheme) *runtime.Scheme { - if scheme == nil { - scheme = runtime.NewScheme() - } +func setupScheme() *runtime.Scheme { + scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) _ = api.AddToScheme(scheme) _ = taskapi.AddToScheme(scheme)
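
Reviewer note (not part of the patches above): the series replaces the mockery-generated controller-runtime client mocks with the fake client plus `interceptor.Funcs` for error injection, and replaces the generated `HttpClient` mock with the small `httpClientDoFunc` adapter so each test defines its own `Do` behavior inline. For readers unfamiliar with the interceptor pattern, the sketch below shows its general shape in isolation. It is a minimal illustration only; the test name, namespace, PVC name, and the injected error are placeholders, not code from this series.

```go
package example

import (
	"context"
	"errors"
	"testing"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes/scheme"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
	"sigs.k8s.io/controller-runtime/pkg/client/interceptor"
)

// Illustrative only: injects a Delete failure for PVCs while every other
// call falls through to the fake client's object tracker.
func TestDeletePVCFails(t *testing.T) {
	wantErr := errors.New("injected delete failure") // placeholder error

	pvc := &corev1.PersistentVolumeClaim{
		ObjectMeta: metav1.ObjectMeta{Namespace: "default", Name: "data-pod-0"},
	}

	c := fake.NewClientBuilder().
		WithScheme(scheme.Scheme).
		WithObjects(pvc).
		WithInterceptorFuncs(interceptor.Funcs{
			Delete: func(ctx context.Context, cl client.WithWatch, obj client.Object, opts ...client.DeleteOption) error {
				if _, ok := obj.(*corev1.PersistentVolumeClaim); ok {
					return wantErr
				}
				// Delegate everything else to the real fake client.
				return cl.Delete(ctx, obj, opts...)
			},
		}).
		Build()

	if err := c.Delete(context.Background(), pvc); !errors.Is(err, wantErr) {
		t.Fatalf("expected injected error, got %v", err)
	}
}
```

Because the interceptor delegates to the underlying fake client for everything it does not intercept, the rest of a test runs against a real object tracker rather than per-call `mock.On` expectations, which appears to be the main simplification the series is after.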