From 610f38cb4a9e030332e23ed9ea981e9c97279678 Mon Sep 17 00:00:00 2001 From: Brendan Burns Date: Fri, 11 Nov 2016 22:55:06 -0800 Subject: [PATCH] Add support for service address ranges to Azure load balancers. --- .../providers/azure/azure_loadbalancer.go | 42 +++++++---- .../providers/azure/azure_test.go | 72 ++++++++++++++----- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go index 83f69ef3403..bd26a41c528 100644 --- a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go +++ b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go @@ -474,25 +474,41 @@ func (az *Cloud) reconcileLoadBalancer(lb network.LoadBalancer, pip *network.Pub func (az *Cloud) reconcileSecurityGroup(sg network.SecurityGroup, clusterName string, service *api.Service) (network.SecurityGroup, bool, error) { serviceName := getServiceName(service) wantLb := len(service.Spec.Ports) > 0 - expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports)) + + sourceRanges, err := serviceapi.GetLoadBalancerSourceRanges(service) + if err != nil { + return sg, false, err + } + var sourceAddressPrefixes []string + if sourceRanges == nil || serviceapi.IsAllowAll(sourceRanges) { + sourceAddressPrefixes = []string{"Internet"} + } else { + for _, ip := range sourceRanges { + sourceAddressPrefixes = append(sourceAddressPrefixes, ip.String()) + } + } + expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports)*len(sourceAddressPrefixes)) + for i, port := range service.Spec.Ports { securityRuleName := getRuleName(service, port) _, securityProto, _, err := getProtocolsFromKubernetesProtocol(port.Protocol) if err != nil { return sg, false, err } - - expectedSecurityRules[i] = network.SecurityRule{ - Name: to.StringPtr(securityRuleName), - Properties: &network.SecurityRulePropertiesFormat{ - Protocol: securityProto, - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr(strconv.Itoa(int(port.Port))), - SourceAddressPrefix: to.StringPtr("Internet"), - DestinationAddressPrefix: to.StringPtr("*"), - Access: network.Allow, - Direction: network.Inbound, - }, + for j := range sourceAddressPrefixes { + ix := i*len(sourceAddressPrefixes) + j + expectedSecurityRules[ix] = network.SecurityRule{ + Name: to.StringPtr(securityRuleName), + Properties: &network.SecurityRulePropertiesFormat{ + Protocol: securityProto, + SourcePortRange: to.StringPtr("*"), + DestinationPortRange: to.StringPtr(strconv.Itoa(int(port.Port))), + SourceAddressPrefix: to.StringPtr(sourceAddressPrefixes[j]), + DestinationAddressPrefix: to.StringPtr("*"), + Access: network.Allow, + Direction: network.Inbound, + }, + } } } diff --git a/pkg/cloudprovider/providers/azure/azure_test.go b/pkg/cloudprovider/providers/azure/azure_test.go index ac3932db887..d48a3e43129 100644 --- a/pkg/cloudprovider/providers/azure/azure_test.go +++ b/pkg/cloudprovider/providers/azure/azure_test.go @@ -187,6 +187,23 @@ func TestReconcileSecurityGroupRemoveServiceRemovesPort(t *testing.T) { validateSecurityGroup(t, sg, svcUpdated) } +func TestReconcileSecurityWithSourceRanges(t *testing.T) { + az := getTestCloud() + svc := getTestService("servicea", 80, 443) + svc.Spec.LoadBalancerSourceRanges = []string{ + "192.168.0.1/24", + "10.0.0.1/32", + } + + sg := getTestSecurityGroup(svc) + sg, _, err := az.reconcileSecurityGroup(sg, testClusterName, &svc) + if err != nil { + t.Errorf("Unexpected error: %q", err) + } + + validateSecurityGroup(t, sg, svc) +} + func getTestCloud() *Cloud { return &Cloud{ Config: Config{ @@ -269,18 +286,30 @@ func getTestLoadBalancer(services ...api.Service) network.LoadBalancer { return lb } +func getServiceSourceRanges(service *api.Service) []string { + if len(service.Spec.LoadBalancerSourceRanges) == 0 { + return []string{"Internet"} + } + return service.Spec.LoadBalancerSourceRanges +} + func getTestSecurityGroup(services ...api.Service) network.SecurityGroup { rules := []network.SecurityRule{} for _, service := range services { for _, port := range service.Spec.Ports { ruleName := getRuleName(&service, port) - rules = append(rules, network.SecurityRule{ - Name: to.StringPtr(ruleName), - Properties: &network.SecurityRulePropertiesFormat{ - DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)), - }, - }) + + sources := getServiceSourceRanges(&service) + for _, src := range sources { + rules = append(rules, network.SecurityRule{ + Name: to.StringPtr(ruleName), + Properties: &network.SecurityRulePropertiesFormat{ + SourceAddressPrefix: to.StringPtr(src), + DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)), + }, + }) + } } } @@ -344,7 +373,7 @@ func validateLoadBalancer(t *testing.T, loadBalancer network.LoadBalancer, servi lenRules := len(*loadBalancer.Properties.LoadBalancingRules) if lenRules != expectedRuleCount { - t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules) + t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n%v", expectedRuleCount, lenRules, loadBalancer.Properties.LoadBalancingRules) } lenProbes := len(*loadBalancer.Properties.Probes) if lenProbes != expectedRuleCount { @@ -356,25 +385,30 @@ func validateSecurityGroup(t *testing.T, securityGroup network.SecurityGroup, se expectedRuleCount := 0 for _, svc := range services { for _, wantedRule := range svc.Spec.Ports { - expectedRuleCount++ - wantedRuleName := getRuleName(&svc, wantedRule) - foundRule := false - for _, actualRule := range *securityGroup.Properties.SecurityRules { - if strings.EqualFold(*actualRule.Name, wantedRuleName) && - *actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) { - foundRule = true - break + sources := getServiceSourceRanges(&svc) + + for _, source := range sources { + expectedRuleCount++ + wantedRuleName := getRuleName(&svc, wantedRule) + foundRule := false + for _, actualRule := range *securityGroup.Properties.SecurityRules { + if strings.EqualFold(*actualRule.Name, wantedRuleName) && + *actualRule.Properties.SourceAddressPrefix == source && + *actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) { + foundRule = true + break + } + } + if !foundRule { + t.Errorf("Expected security group rule but didn't find it: %q", wantedRuleName) } - } - if !foundRule { - t.Errorf("Expected security group rule but didn't find it: %q", wantedRuleName) } } } lenRules := len(*securityGroup.Properties.SecurityRules) if lenRules != expectedRuleCount { - t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules) + t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n", expectedRuleCount, lenRules) } }