diff --git a/pkg/api/service/testing/make.go b/pkg/api/service/testing/make.go index 45bb1ef8db4..3709e784598 100644 --- a/pkg/api/service/testing/make.go +++ b/pkg/api/service/testing/make.go @@ -155,6 +155,13 @@ func SetInternalTrafficPolicy(policy api.ServiceInternalTrafficPolicyType) Tweak } } +// SetExternalTrafficPolicy sets the externalTrafficPolicy field for a Service. +func SetExternalTrafficPolicy(policy api.ServiceExternalTrafficPolicyType) Tweak { + return func(svc *api.Service) { + svc.Spec.ExternalTrafficPolicy = policy + } +} + // SetAllocateLoadBalancerNodePorts sets the allocate LB node port field. func SetAllocateLoadBalancerNodePorts(val bool) Tweak { return func(svc *api.Service) { diff --git a/pkg/registry/core/service/storage/rest_test.go b/pkg/registry/core/service/storage/rest_test.go index 9679026325e..66ad616982b 100644 --- a/pkg/registry/core/service/storage/rest_test.go +++ b/pkg/registry/core/service/storage/rest_test.go @@ -621,6 +621,99 @@ func TestServiceRegistryUpdate(t *testing.T) { } } +func TestServiceRegistryUpdateUnspecifiedAllocations(t *testing.T) { + testCases := []struct { + name string + svc *api.Service // Need a clusterIP, NodePort, and HealthCheckNodePort allocated + tweak func(*api.Service) + }{{ + name: "single-port", + svc: svctest.MakeService("foo", + svctest.SetTypeLoadBalancer, + svctest.SetExternalTrafficPolicy(api.ServiceExternalTrafficPolicyTypeLocal)), + tweak: nil, + }, { + name: "multi-port", + svc: svctest.MakeService("foo", + svctest.SetTypeLoadBalancer, + svctest.SetExternalTrafficPolicy(api.ServiceExternalTrafficPolicyTypeLocal), + svctest.SetPorts( + svctest.MakeServicePort("p", 80, intstr.FromInt(80), api.ProtocolTCP), + svctest.MakeServicePort("q", 443, intstr.FromInt(443), api.ProtocolTCP))), + tweak: nil, + }, { + name: "shuffle-ports", + svc: svctest.MakeService("foo", + svctest.SetTypeLoadBalancer, + svctest.SetExternalTrafficPolicy(api.ServiceExternalTrafficPolicyTypeLocal), + svctest.SetPorts( + svctest.MakeServicePort("p", 80, intstr.FromInt(80), api.ProtocolTCP), + svctest.MakeServicePort("q", 443, intstr.FromInt(443), api.ProtocolTCP))), + tweak: func(s *api.Service) { + s.Spec.Ports[0], s.Spec.Ports[1] = s.Spec.Ports[1], s.Spec.Ports[0] + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := genericapirequest.NewDefaultContext() + storage, server := NewTestREST(t, []api.IPFamily{api.IPv4Protocol}) + defer server.Terminate(t) + + svc := tc.svc.DeepCopy() + obj, err := storage.Create(ctx, svc.DeepCopy(), rest.ValidateAllObjectFunc, &metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected no error: %v", err) + } + createdSvc := obj.(*api.Service) + if createdSvc.Spec.ClusterIP == "" { + t.Fatalf("expected ClusterIP to be set") + } + if len(createdSvc.Spec.ClusterIPs) == 0 { + t.Fatalf("expected ClusterIPs to be set") + } + for i := range createdSvc.Spec.Ports { + if createdSvc.Spec.Ports[i].NodePort == 0 { + t.Fatalf("expected NodePort[%d] to be set", i) + } + } + if createdSvc.Spec.HealthCheckNodePort == 0 { + t.Fatalf("expected HealthCheckNodePort to be set") + } + + // Update from the original object - just change the selector. + svc.Spec.Selector = map[string]string{"bar": "baz2"} + svc.ResourceVersion = createdSvc.ResourceVersion + + obj, _, err = storage.Update(ctx, svc.Name, rest.DefaultUpdatedObjectInfo(svc.DeepCopy()), rest.ValidateAllObjectFunc, rest.ValidateAllObjectUpdateFunc, false, &metav1.UpdateOptions{}) + if err != nil { + t.Fatalf("Expected no error: %v", err) + } + updatedSvc := obj.(*api.Service) + + if want, got := createdSvc.Spec.ClusterIP, updatedSvc.Spec.ClusterIP; want != got { + t.Errorf("expected ClusterIP to not change: wanted %v, got %v", want, got) + } + if want, got := createdSvc.Spec.ClusterIPs, updatedSvc.Spec.ClusterIPs; !reflect.DeepEqual(want, got) { + t.Errorf("expected ClusterIPs to not change: wanted %v, got %v", want, got) + } + portmap := func(s *api.Service) map[string]int32 { + ret := map[string]int32{} + for _, p := range s.Spec.Ports { + ret[p.Name] = p.NodePort + } + return ret + } + if want, got := portmap(createdSvc), portmap(updatedSvc); !reflect.DeepEqual(want, got) { + t.Errorf("expected NodePort to not change: wanted %v, got %v", want, got) + } + if want, got := createdSvc.Spec.HealthCheckNodePort, updatedSvc.Spec.HealthCheckNodePort; want != got { + t.Errorf("expected HealthCheckNodePort to not change: wanted %v, got %v", want, got) + } + }) + } +} + func TestServiceRegistryUpdateDryRun(t *testing.T) { ctx := genericapirequest.NewDefaultContext() storage, server := NewTestREST(t, []api.IPFamily{api.IPv4Protocol}) diff --git a/pkg/registry/core/service/strategy.go b/pkg/registry/core/service/strategy.go index 0291f2ce35a..6e3553acd38 100644 --- a/pkg/registry/core/service/strategy.go +++ b/pkg/registry/core/service/strategy.go @@ -119,6 +119,7 @@ func (strategy svcStrategy) PrepareForUpdate(ctx context.Context, obj, old runti oldService := old.(*api.Service) newService.Status = oldService.Status + patchAllocatedValues(newService, oldService) NormalizeClusterIPs(oldService, newService) dropServiceDisabledFields(newService, oldService) dropTypeDependentFields(newService, oldService) @@ -302,6 +303,43 @@ func (serviceStatusStrategy) WarningsOnUpdate(ctx context.Context, obj, old runt return nil } +// patchAllocatedValues allows clients to avoid a read-modify-write cycle while +// preserving values that we allocated on their behalf. For example, they +// might create a Service without specifying the ClusterIP, in which case we +// allocate one. If they resubmit that same YAML, we want it to succeed. +func patchAllocatedValues(newSvc, oldSvc *api.Service) { + if needsClusterIP(oldSvc) && needsClusterIP(newSvc) { + if newSvc.Spec.ClusterIP == "" { + newSvc.Spec.ClusterIP = oldSvc.Spec.ClusterIP + } + if len(newSvc.Spec.ClusterIPs) == 0 { + newSvc.Spec.ClusterIPs = oldSvc.Spec.ClusterIPs + } + } + + if needsNodePort(oldSvc) && needsNodePort(newSvc) { + // Map NodePorts by name. The user may have changed other properties + // of the port, but we won't see that here. + np := map[string]int32{} + for i := range oldSvc.Spec.Ports { + p := &oldSvc.Spec.Ports[i] + np[p.Name] = p.NodePort + } + for i := range newSvc.Spec.Ports { + p := &newSvc.Spec.Ports[i] + if p.NodePort == 0 { + p.NodePort = np[p.Name] + } + } + } + + if needsHCNodePort(oldSvc) && needsHCNodePort(newSvc) { + if newSvc.Spec.HealthCheckNodePort == 0 { + newSvc.Spec.HealthCheckNodePort = oldSvc.Spec.HealthCheckNodePort + } + } +} + // NormalizeClusterIPs adjust clusterIPs based on ClusterIP. This must not // consider any other fields. func NormalizeClusterIPs(oldSvc, newSvc *api.Service) {