diff --git a/staging/src/k8s.io/legacy-cloud-providers/gce/gce_loadbalancer_internal_test.go b/staging/src/k8s.io/legacy-cloud-providers/gce/gce_loadbalancer_internal_test.go index f3b2d902ce3..be9560a22d6 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/gce/gce_loadbalancer_internal_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/gce/gce_loadbalancer_internal_test.go @@ -32,11 +32,11 @@ import ( "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/mock" "google.golang.org/api/compute/v1" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" - "k8s.io/cloud-provider" + cloudprovider "k8s.io/cloud-provider" servicehelper "k8s.io/cloud-provider/service/helpers" ) @@ -1638,3 +1638,64 @@ func TestEnsureLoadBalancerPartialDelete(t *testing.T) { require.NoError(t, err) assert.False(t, exists) } + +func TestEnsureInternalLoadBalancerModifyProtocol(t *testing.T) { + t.Parallel() + + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + require.NoError(t, err) + c := gce.c.(*cloud.MockGCE) + c.MockRegionBackendServices.UpdateHook = func(ctx context.Context, key *meta.Key, be *compute.BackendService, m *cloud.MockRegionBackendServices) error { + // Same key can be used since FR will have the same name. + fr, err := c.MockForwardingRules.Get(ctx, key) + if err != nil && !isNotFound(err) { + return err + } + if fr != nil && fr.IPProtocol != be.Protocol { + return fmt.Errorf("Protocol mismatch between Forwarding Rule value %q and Backend service value %q", fr.IPProtocol, be.Protocol) + } + return mock.UpdateRegionBackendServiceHook(ctx, key, be, m) + } + nodeNames := []string{"test-node-1"} + nodes, err := createAndInsertNodes(gce, nodeNames, vals.ZoneName) + require.NoError(t, err) + svc := fakeLoadbalancerService(string(LBTypeInternal)) + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(context.TODO(), svc, metav1.CreateOptions{}) + require.NoError(t, err) + lbName := gce.GetLoadBalancerName(context.TODO(), "", svc) + status, err := createInternalLoadBalancer(gce, svc, nil, nodeNames, vals.ClusterName, vals.ClusterID, vals.ZoneName) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + assert.NotEmpty(t, status.Ingress) + fwdRule, err := gce.GetRegionForwardingRule(lbName, gce.region) + if err != nil { + t.Errorf("gce.GetRegionForwardingRule(%q, %q) = %v, want nil", lbName, gce.region, err) + } + if fwdRule.IPProtocol != "TCP" { + t.Errorf("Unexpected protocol value %s, expected TCP", fwdRule.IPProtocol) + } + + // change the protocol to UDP + svc.Spec.Ports[0].Protocol = v1.ProtocolUDP + status, err = gce.EnsureLoadBalancer(context.Background(), vals.ClusterName, svc, nodes) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + assert.NotEmpty(t, status.Ingress) + fwdRule, err = gce.GetRegionForwardingRule(lbName, gce.region) + if err != nil { + t.Errorf("gce.GetRegionForwardingRule(%q, %q) = %v, want nil", lbName, gce.region, err) + } + if fwdRule.IPProtocol != "UDP" { + t.Errorf("Unexpected protocol value %s, expected UDP", fwdRule.IPProtocol) + } + + // Delete the service + err = gce.EnsureLoadBalancerDeleted(context.Background(), vals.ClusterName, svc) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + assertInternalLbResourcesDeleted(t, gce, svc, vals, true) +}