diff --git a/pkg/registry/core/service/storage/storage_test.go b/pkg/registry/core/service/storage/storage_test.go index 5d72cd610ed..bb06ea575aa 100644 --- a/pkg/registry/core/service/storage/storage_test.go +++ b/pkg/registry/core/service/storage/storage_test.go @@ -17,6 +17,7 @@ limitations under the License. package storage import ( + "context" "fmt" "net" "reflect" @@ -87,11 +88,11 @@ func portIsAllocated(t *testing.T, alloc portallocator.Interface, port int32) bo return alloc.Has(int(port)) } -func newStorage(t *testing.T, ipFamilies []api.IPFamily) (*GenericREST, *StatusREST, *etcd3testing.EtcdTestServer) { +func newStorage(t *testing.T, ipFamilies []api.IPFamily) (*wrapperRESTForTests, *StatusREST, *etcd3testing.EtcdTestServer) { return newStorageWithPods(t, ipFamilies, nil, nil) } -func newStorageWithPods(t *testing.T, ipFamilies []api.IPFamily, pods []api.Pod, endpoints []*api.Endpoints) (*GenericREST, *StatusREST, *etcd3testing.EtcdTestServer) { +func newStorageWithPods(t *testing.T, ipFamilies []api.IPFamily, pods []api.Pod, endpoints []*api.Endpoints) (*wrapperRESTForTests, *StatusREST, *etcd3testing.EtcdTestServer) { etcdStorage, server := registrytest.NewEtcdStorage(t, "") restOptions := generic.RESTOptions{ StorageConfig: etcdStorage.ForResource(schema.GroupResource{Resource: "services"}), @@ -158,7 +159,21 @@ func newStorageWithPods(t *testing.T, ipFamilies []api.IPFamily, pods []api.Pod, if err != nil { t.Fatalf("unexpected error from REST storage: %v", err) } - return serviceStorage, statusStorage, server + return &wrapperRESTForTests{serviceStorage}, statusStorage, server +} + +// wrapperRESTForTests is a *trivial* wrapper for the real REST, which allows us to do +// things that are specifically to enhance test safety. +type wrapperRESTForTests struct { + *GenericREST +} + +func (f *wrapperRESTForTests) Create(ctx context.Context, obj runtime.Object, createValidation rest.ValidateObjectFunc, options *metav1.CreateOptions) (runtime.Object, error) { + // Making a DeepCopy here ensures that any in-place mutations of the input + // are not going to propagate to verification code, which used to happen + // resulting in tests that passed when they shouldn't have. + obj = obj.DeepCopyObject() + return f.GenericREST.Create(ctx, obj, createValidation, options) } // This is used in generic registry tests. @@ -6301,14 +6316,14 @@ func TestUpdatePatchAllocatedValues(t *testing.T) { return proofs } proveClusterIP := func(idx int, ip string) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { if want, got := ip, after.Spec.ClusterIPs[idx]; want != got { t.Errorf("wrong ClusterIPs[%d]: want %q, got %q", idx, want, got) } } } proveNodePort := func(idx int, port int32) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { got := after.Spec.Ports[idx].NodePort if port > 0 && got != port { t.Errorf("wrong Ports[%d].NodePort: want %d, got %d", idx, port, got) @@ -6318,7 +6333,7 @@ func TestUpdatePatchAllocatedValues(t *testing.T) { } } proveHCNP := func(port int32) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { got := after.Spec.HealthCheckNodePort if port > 0 && got != port { t.Errorf("wrong HealthCheckNodePort: want %d, got %d", port, got) @@ -6568,7 +6583,7 @@ func TestUpdateIPsFromSingleStack(t *testing.T) { return proofs } proveNumFamilies := func(n int) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if got := len(after.Spec.IPFamilies); got != n { t.Errorf("wrong number of ipFamilies: expected %d, got %d", n, got) @@ -7328,7 +7343,7 @@ func TestUpdateIPsFromSingleStack(t *testing.T) { expectClusterIPs: true, prove: prove(proveNumFamilies(1)), }, - beforeUpdate: func(t *testing.T, storage *GenericREST) { + beforeUpdate: func(t *testing.T, storage *wrapperRESTForTests) { alloc := storage.alloc.serviceIPAllocatorsByFamily[api.IPv6Protocol] ip := "2000::1" if err := alloc.Allocate(netutils.ParseIPSloppy(ip)); err != nil { @@ -8036,7 +8051,7 @@ func TestUpdateIPsFromSingleStack(t *testing.T) { expectClusterIPs: true, prove: prove(proveNumFamilies(1)), }, - beforeUpdate: func(t *testing.T, storage *GenericREST) { + beforeUpdate: func(t *testing.T, storage *wrapperRESTForTests) { alloc := storage.alloc.serviceIPAllocatorsByFamily[api.IPv4Protocol] ip := "10.0.0.1" if err := alloc.Allocate(netutils.ParseIPSloppy(ip)); err != nil { @@ -8280,7 +8295,7 @@ func TestUpdateIPsFromDualStack(t *testing.T) { return proofs } proveNumFamilies := func(n int) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if got := len(after.Spec.IPFamilies); got != n { t.Errorf("wrong number of ipFamilies: expected %d, got %d", n, got) @@ -9889,7 +9904,7 @@ type svcTestCase struct { prove []svcTestProof } -type svcTestProof func(t *testing.T, storage *GenericREST, before, after *api.Service) +type svcTestProof func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) func callName(before, after *api.Service) string { if before == nil && after != nil { @@ -9904,7 +9919,7 @@ func callName(before, after *api.Service) string { panic("this test is broken: before and after are both nil") } -func proveClusterIPsAllocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveClusterIPsAllocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if sing, plur := after.Spec.ClusterIP, after.Spec.ClusterIPs[0]; sing != plur { @@ -9980,7 +9995,7 @@ func proveClusterIPsAllocated(t *testing.T, storage *GenericREST, before, after } } -func proveClusterIPsDeallocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveClusterIPsDeallocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if after != nil && after.Spec.ClusterIP != api.ClusterIPNone { @@ -10007,7 +10022,7 @@ func proveClusterIPsDeallocated(t *testing.T, storage *GenericREST, before, afte } } -func proveHeadless(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveHeadless(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if sing, plur := after.Spec.ClusterIP, after.Spec.ClusterIPs[0]; sing != plur { @@ -10018,7 +10033,7 @@ func proveHeadless(t *testing.T, storage *GenericREST, before, after *api.Servic } } -func proveNodePortsAllocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveNodePortsAllocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() for _, p := range after.Spec.Ports { @@ -10028,7 +10043,7 @@ func proveNodePortsAllocated(t *testing.T, storage *GenericREST, before, after * } } -func proveNodePortsDeallocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveNodePortsDeallocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if after != nil { @@ -10048,7 +10063,7 @@ func proveNodePortsDeallocated(t *testing.T, storage *GenericREST, before, after } } -func proveHealthCheckNodePortAllocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveHealthCheckNodePortAllocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if !portIsAllocated(t, storage.alloc.serviceNodePorts, after.Spec.HealthCheckNodePort) { @@ -10056,7 +10071,7 @@ func proveHealthCheckNodePortAllocated(t *testing.T, storage *GenericREST, befor } } -func proveHealthCheckNodePortDeallocated(t *testing.T, storage *GenericREST, before, after *api.Service) { +func proveHealthCheckNodePortDeallocated(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if after != nil { @@ -10076,7 +10091,7 @@ type cudTestCase struct { name string line string // if not empty, will be logged with errors, use line() to set create svcTestCase - beforeUpdate func(t *testing.T, storage *GenericREST) + beforeUpdate func(t *testing.T, storage *wrapperRESTForTests) update svcTestCase } @@ -10380,7 +10395,7 @@ func TestVerifyEquiv(t *testing.T) { } } -func verifyExpectations(t *testing.T, storage *GenericREST, tc svcTestCase, before, after *api.Service) { +func verifyExpectations(t *testing.T, storage *wrapperRESTForTests, tc svcTestCase, before, after *api.Service) { t.Helper() if tc.expectClusterIPs { @@ -11341,7 +11356,7 @@ func TestFeatureInternalTrafficPolicy(t *testing.T) { return proofs } proveITP := func(want api.ServiceInternalTrafficPolicyType) svcTestProof { - return func(t *testing.T, storage *GenericREST, before, after *api.Service) { + return func(t *testing.T, storage *wrapperRESTForTests, before, after *api.Service) { t.Helper() if got := after.Spec.InternalTrafficPolicy; got == nil { if want != "" {