From 7dea3d6c3bf924e40089b8d3656e8eaa100a29fc Mon Sep 17 00:00:00 2001 From: Prince Pereira Date: Tue, 22 Aug 2023 16:20:18 +0530 Subject: [PATCH] New mock test framework for windows kubeproxy. --- pkg/proxy/winkernel/hcnutils.go | 135 ++++++++++ pkg/proxy/winkernel/hns.go | 38 +-- pkg/proxy/winkernel/hns_test.go | 20 +- pkg/proxy/winkernel/proxier.go | 57 ++--- pkg/proxy/winkernel/proxier_test.go | 250 +++++++++---------- pkg/proxy/winkernel/testing/hcnutils_mock.go | 212 ++++++++++++++++ 6 files changed, 511 insertions(+), 201 deletions(-) create mode 100644 pkg/proxy/winkernel/hcnutils.go create mode 100644 pkg/proxy/winkernel/testing/hcnutils_mock.go diff --git a/pkg/proxy/winkernel/hcnutils.go b/pkg/proxy/winkernel/hcnutils.go new file mode 100644 index 00000000000..ec6cf81651d --- /dev/null +++ b/pkg/proxy/winkernel/hcnutils.go @@ -0,0 +1,135 @@ +//go:build windows +// +build windows + +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package winkernel + +import ( + "github.com/Microsoft/hcsshim" + "github.com/Microsoft/hcsshim/hcn" + "k8s.io/klog/v2" +) + +type HcnService interface { + // Network functions + GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) + GetNetworkByID(networkID string) (*hcn.HostComputeNetwork, error) + // Endpoint functions + ListEndpoints() ([]hcn.HostComputeEndpoint, error) + ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) + GetEndpointByID(endpointId string) (*hcn.HostComputeEndpoint, error) + GetEndpointByName(endpointName string) (*hcn.HostComputeEndpoint, error) + CreateEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) + CreateRemoteEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) + DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error + // LoadBalancer functions + ListLoadBalancers() ([]hcn.HostComputeLoadBalancer, error) + GetLoadBalancerByID(loadBalancerId string) (*hcn.HostComputeLoadBalancer, error) + CreateLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) (*hcn.HostComputeLoadBalancer, error) + DeleteLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) error + // Features functions + GetSupportedFeatures() hcn.SupportedFeatures + Ipv6DualStackSupported() error + DsrSupported() error + // Policy functions + DeleteAllHnsLoadBalancerPolicy() +} + +type hcnImpl struct{} + +func newHcnImpl() hcnImpl { + return hcnImpl{} +} + +func (hcnObj hcnImpl) GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) { + return hcn.GetNetworkByName(networkName) +} + +func (hcnObj hcnImpl) GetNetworkByID(networkID string) (*hcn.HostComputeNetwork, error) { + return hcn.GetNetworkByID(networkID) +} + +func (hcnObj hcnImpl) ListEndpoints() ([]hcn.HostComputeEndpoint, error) { + return hcn.ListEndpoints() +} + +func (hcnObj hcnImpl) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) { + return hcn.ListEndpointsOfNetwork(networkId) +} + +func (hcnObj hcnImpl) GetEndpointByID(endpointId string) (*hcn.HostComputeEndpoint, error) { + return hcn.GetEndpointByID(endpointId) +} + +func (hcnObj hcnImpl) GetEndpointByName(endpointName string) (*hcn.HostComputeEndpoint, error) { + return hcn.GetEndpointByName(endpointName) +} + +func (hcnObj hcnImpl) CreateEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { + return network.CreateEndpoint(endpoint) +} + +func (hcnObj hcnImpl) CreateRemoteEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { + return network.CreateRemoteEndpoint(endpoint) +} + +func (hcnObj hcnImpl) DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error { + return endpoint.Delete() +} + +func (hcnObj hcnImpl) ListLoadBalancers() ([]hcn.HostComputeLoadBalancer, error) { + return hcn.ListLoadBalancers() +} + +func (hcnObj hcnImpl) GetLoadBalancerByID(loadBalancerId string) (*hcn.HostComputeLoadBalancer, error) { + return hcn.GetLoadBalancerByID(loadBalancerId) +} + +func (hcnObj hcnImpl) CreateLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) (*hcn.HostComputeLoadBalancer, error) { + return loadBalancer.Create() +} + +func (hcnObj hcnImpl) DeleteLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) error { + return loadBalancer.Delete() +} + +func (hcnObj hcnImpl) GetSupportedFeatures() hcn.SupportedFeatures { + return hcn.GetSupportedFeatures() +} + +func (hcnObj hcnImpl) Ipv6DualStackSupported() error { + return hcn.IPv6DualStackSupported() +} + +func (hcnObj hcnImpl) DsrSupported() error { + return hcn.DSRSupported() +} + +func (hcnObj hcnImpl) DeleteAllHnsLoadBalancerPolicy() { + plists, err := hcsshim.HNSListPolicyListRequest() + if err != nil { + return + } + for _, plist := range plists { + klog.V(3).InfoS("Remove policy", "policies", plist) + _, err = plist.Delete() + if err != nil { + klog.ErrorS(err, "Failed to delete policy list") + } + } +} diff --git a/pkg/proxy/winkernel/hns.go b/pkg/proxy/winkernel/hns.go index f7e6dd8fe9b..9fa83e14ef9 100644 --- a/pkg/proxy/winkernel/hns.go +++ b/pkg/proxy/winkernel/hns.go @@ -43,7 +43,9 @@ type HostNetworkService interface { deleteLoadBalancer(hnsID string) error } -type hns struct{} +type hns struct { + hcn HcnService +} var ( // LoadBalancerFlagsIPv6 enables IPV6. @@ -53,7 +55,7 @@ var ( ) func (hns hns) getNetworkByName(name string) (*hnsNetworkInfo, error) { - hnsnetwork, err := hcn.GetNetworkByName(name) + hnsnetwork, err := hns.hcn.GetNetworkByName(name) if err != nil { klog.ErrorS(err, "Error getting network by name") return nil, err @@ -86,12 +88,12 @@ func (hns hns) getNetworkByName(name string) (*hnsNetworkInfo, error) { } func (hns hns) getAllEndpointsByNetwork(networkName string) (map[string]*(endpointsInfo), error) { - hcnnetwork, err := hcn.GetNetworkByName(networkName) + hcnnetwork, err := hns.hcn.GetNetworkByName(networkName) if err != nil { klog.ErrorS(err, "failed to get HNS network by name", "name", networkName) return nil, err } - endpoints, err := hcn.ListEndpointsOfNetwork(hcnnetwork.Id) + endpoints, err := hns.hcn.ListEndpointsOfNetwork(hcnnetwork.Id) if err != nil { return nil, fmt.Errorf("failed to list endpoints: %w", err) } @@ -144,7 +146,7 @@ func (hns hns) getAllEndpointsByNetwork(networkName string) (map[string]*(endpoi } func (hns hns) getEndpointByID(id string) (*endpointsInfo, error) { - hnsendpoint, err := hcn.GetEndpointByID(id) + hnsendpoint, err := hns.hcn.GetEndpointByID(id) if err != nil { return nil, err } @@ -157,13 +159,13 @@ func (hns hns) getEndpointByID(id string) (*endpointsInfo, error) { }, nil } func (hns hns) getEndpointByIpAddress(ip string, networkName string) (*endpointsInfo, error) { - hnsnetwork, err := hcn.GetNetworkByName(networkName) + hnsnetwork, err := hns.hcn.GetNetworkByName(networkName) if err != nil { klog.ErrorS(err, "Error getting network by name") return nil, err } - endpoints, err := hcn.ListEndpoints() + endpoints, err := hns.hcn.ListEndpoints() if err != nil { return nil, fmt.Errorf("failed to list endpoints: %w", err) } @@ -189,7 +191,7 @@ func (hns hns) getEndpointByIpAddress(ip string, networkName string) (*endpoints return nil, fmt.Errorf("Endpoint %v not found on network %s", ip, networkName) } func (hns hns) getEndpointByName(name string) (*endpointsInfo, error) { - hnsendpoint, err := hcn.GetEndpointByName(name) + hnsendpoint, err := hns.hcn.GetEndpointByName(name) if err != nil { return nil, err } @@ -202,7 +204,7 @@ func (hns hns) getEndpointByName(name string) (*endpointsInfo, error) { }, nil } func (hns hns) createEndpoint(ep *endpointsInfo, networkName string) (*endpointsInfo, error) { - hnsNetwork, err := hcn.GetNetworkByName(networkName) + hnsNetwork, err := hns.hcn.GetNetworkByName(networkName) if err != nil { return nil, err } @@ -239,12 +241,12 @@ func (hns hns) createEndpoint(ep *endpointsInfo, networkName string) (*endpoints } hnsEndpoint.Policies = append(hnsEndpoint.Policies, paPolicy) } - createdEndpoint, err = hnsNetwork.CreateRemoteEndpoint(hnsEndpoint) + createdEndpoint, err = hns.hcn.CreateRemoteEndpoint(hnsNetwork, hnsEndpoint) if err != nil { return nil, err } } else { - createdEndpoint, err = hnsNetwork.CreateEndpoint(hnsEndpoint) + createdEndpoint, err = hns.hcn.CreateEndpoint(hnsNetwork, hnsEndpoint) if err != nil { return nil, err } @@ -259,11 +261,11 @@ func (hns hns) createEndpoint(ep *endpointsInfo, networkName string) (*endpoints }, nil } func (hns hns) deleteEndpoint(hnsID string) error { - hnsendpoint, err := hcn.GetEndpointByID(hnsID) + hnsendpoint, err := hns.hcn.GetEndpointByID(hnsID) if err != nil { return err } - err = hnsendpoint.Delete() + err = hns.hcn.DeleteEndpoint(hnsendpoint) if err == nil { klog.V(3).InfoS("Remote endpoint resource deleted", "hnsID", hnsID) } @@ -285,7 +287,7 @@ func findLoadBalancerID(endpoints []endpointsInfo, vip string, protocol, interna } func (hns hns) getAllLoadBalancers() (map[loadBalancerIdentifier]*loadBalancerInfo, error) { - lbs, err := hcn.ListLoadBalancers() + lbs, err := hns.hcn.ListLoadBalancers() var id loadBalancerIdentifier if err != nil { return nil, err @@ -389,7 +391,7 @@ func (hns hns) getLoadBalancer(endpoints []endpointsInfo, flags loadBalancerFlag loadBalancer.HostComputeEndpoints = append(loadBalancer.HostComputeEndpoints, ep.hnsID) } - lb, err := loadBalancer.Create() + lb, err := hns.hcn.CreateLoadBalancer(loadBalancer) if err != nil { return nil, err @@ -405,18 +407,18 @@ func (hns hns) getLoadBalancer(endpoints []endpointsInfo, flags loadBalancerFlag } func (hns hns) deleteLoadBalancer(hnsID string) error { - lb, err := hcn.GetLoadBalancerByID(hnsID) + lb, err := hns.hcn.GetLoadBalancerByID(hnsID) if err != nil { // Return silently return nil } - err = lb.Delete() + err = hns.hcn.DeleteLoadBalancer(lb) if err != nil { // There is a bug in Windows Server 2019, that can cause the delete call to fail sometimes. We retry one more time. // TODO: The logic in syncProxyRules should be rewritten in the future to better stage and handle a call like this failing using the policyApplied fields. klog.V(1).ErrorS(err, "Error deleting Hns loadbalancer policy resource. Attempting one more time...", "loadBalancer", lb) - return lb.Delete() + return hns.hcn.DeleteLoadBalancer(lb) } return err } diff --git a/pkg/proxy/winkernel/hns_test.go b/pkg/proxy/winkernel/hns_test.go index 66fb19e0b4d..6e674425f78 100644 --- a/pkg/proxy/winkernel/hns_test.go +++ b/pkg/proxy/winkernel/hns_test.go @@ -48,7 +48,7 @@ const ( ) func TestGetNetworkByName(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) network, err := hns.getNetworkByName(Network.Name) @@ -66,7 +66,7 @@ func TestGetNetworkByName(t *testing.T) { } func TestGetAllEndpointsByNetwork(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) ipv4Config := &hcn.IpConfig{ @@ -111,7 +111,7 @@ func TestGetAllEndpointsByNetwork(t *testing.T) { } func TestGetEndpointByID(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) ipConfig := &hcn.IpConfig{ @@ -150,7 +150,7 @@ func TestGetEndpointByID(t *testing.T) { } func TestGetEndpointByIpAddressAndName(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) ipConfig := &hcn.IpConfig{ @@ -200,7 +200,7 @@ func TestGetEndpointByIpAddressAndName(t *testing.T) { } func TestCreateEndpointLocal(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) endpoint := &endpointsInfo{ @@ -238,7 +238,7 @@ func TestCreateEndpointLocal(t *testing.T) { } func TestCreateEndpointRemote(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) providerAddress := epPaAddress @@ -281,7 +281,7 @@ func TestCreateEndpointRemote(t *testing.T) { } func TestDeleteEndpoint(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) ipConfig := &hcn.IpConfig{ @@ -316,7 +316,7 @@ func TestDeleteEndpoint(t *testing.T) { } func TestGetLoadBalancerExisting(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) lbs := make(map[loadBalancerIdentifier]*(loadBalancerInfo)) @@ -389,7 +389,7 @@ func TestGetLoadBalancerExisting(t *testing.T) { } func TestGetLoadBalancerNew(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) // We keep this empty to ensure we test for new load balancer creation. lbs := make(map[loadBalancerIdentifier]*(loadBalancerInfo)) @@ -441,7 +441,7 @@ func TestGetLoadBalancerNew(t *testing.T) { } func TestDeleteLoadBalancer(t *testing.T) { - hns := hns{} + hns := hns{hcn: newHcnImpl()} Network := mustTestNetwork(t) ipConfig := &hcn.IpConfig{ diff --git a/pkg/proxy/winkernel/proxier.go b/pkg/proxy/winkernel/proxier.go index 75011f94566..642a60cdece 100644 --- a/pkg/proxy/winkernel/proxier.go +++ b/pkg/proxy/winkernel/proxier.go @@ -152,11 +152,13 @@ const ( MAX_COUNT_STALE_LOADBALANCERS = 20 ) -func newHostNetworkService() (HostNetworkService, hcn.SupportedFeatures) { +func newHostNetworkService(hcnImpl HcnService) (HostNetworkService, hcn.SupportedFeatures) { var h HostNetworkService - supportedFeatures := hcn.GetSupportedFeatures() + supportedFeatures := hcnImpl.GetSupportedFeatures() if supportedFeatures.Api.V2 { - h = hns{} + h = hns{ + hcn: hcnImpl, + } } else { panic("Windows HNS Api V2 required. This version of windows does not support API V2") } @@ -234,8 +236,9 @@ type StackCompatTester interface { type DualStackCompatTester struct{} func (t DualStackCompatTester) DualStackCompatible(networkName string) bool { + hcnImpl := newHcnImpl() // First tag of hcsshim that has a proper check for dual stack support is v0.8.22 due to a bug. - if err := hcn.IPv6DualStackSupported(); err != nil { + if err := hcnImpl.Ipv6DualStackSupported(); err != nil { // Hcn *can* fail the query to grab the version of hcn itself (which this call will do internally before parsing // to see if dual stack is supported), but the only time this can happen, at least that can be discerned, is if the host // is pre-1803 and hcn didn't exist. hcsshim should truthfully return a known error if this happened that we can @@ -248,7 +251,7 @@ func (t DualStackCompatTester) DualStackCompatible(networkName string) bool { } // check if network is using overlay - hns, _ := newHostNetworkService() + hns, _ := newHostNetworkService(hcnImpl) networkName, err := getNetworkName(networkName) if err != nil { klog.ErrorS(err, "Unable to determine dual-stack status, falling back to single-stack") @@ -535,7 +538,8 @@ func (proxier *Proxier) newServiceInfo(port *v1.ServicePort, service *v1.Service if service.Spec.InternalTrafficPolicy != nil { internalTrafficLocal = *service.Spec.InternalTrafficPolicy == v1.ServiceInternalTrafficPolicyLocal } - err := hcn.DSRSupported() + hcnImpl := proxier.hcn + err := hcnImpl.DsrSupported() if err != nil { preserveDIP = false localTrafficDSR = false @@ -621,6 +625,7 @@ type Proxier struct { healthzServer healthcheck.ProxierHealthUpdater hns HostNetworkService + hcn HcnService network hnsNetworkInfo sourceVip string hostMac string @@ -695,14 +700,15 @@ func NewProxier( nodePortAddresses := proxyutil.NewNodePortAddresses(ipFamily, nil) serviceHealthServer := healthcheck.NewServiceHealthServer(hostname, recorder, nodePortAddresses, healthzServer) - hns, supportedFeatures := newHostNetworkService() + hcnImpl := newHcnImpl() + hns, supportedFeatures := newHostNetworkService(hcnImpl) hnsNetworkName, err := getNetworkName(config.NetworkName) if err != nil { return nil, err } klog.V(3).InfoS("Cleaning up old HNS policy lists") - deleteAllHnsLoadBalancerPolicy() + hcnImpl.DeleteAllHnsLoadBalancerPolicy() // Get HNS network information hnsNetworkInfo, err := getNetworkInfo(hns, hnsNetworkName) @@ -725,7 +731,8 @@ func NewProxier( if isDSR && !utilfeature.DefaultFeatureGate.Enabled(kubefeatures.WinDSR) { return nil, fmt.Errorf("WinDSR feature gate not enabled") } - err = hcn.DSRSupported() + + err = hcnImpl.DsrSupported() if isDSR && err != nil { return nil, err } @@ -781,6 +788,7 @@ func NewProxier( serviceHealthServer: serviceHealthServer, healthzServer: healthzServer, hns: hns, + hcn: hcnImpl, network: *hnsNetworkInfo, sourceVip: sourceVip, hostMac: hostMac, @@ -841,7 +849,7 @@ func NewDualStackProxier( // It returns true if an error was encountered. Errors are logged. func CleanupLeftovers() (encounteredError bool) { // Delete all Hns Load Balancer Policies - deleteAllHnsLoadBalancerPolicy() + newHcnImpl().DeleteAllHnsLoadBalancerPolicy() // TODO // Delete all Hns Remote endpoints @@ -926,35 +934,6 @@ func (svcInfo *serviceInfo) deleteLoadBalancerPolicy(mapStaleLoadbalancer map[st } } -func deleteAllHnsLoadBalancerPolicy() { - plists, err := hcsshim.HNSListPolicyListRequest() - if err != nil { - return - } - for _, plist := range plists { - klog.V(3).InfoS("Remove policy", "policies", plist) - _, err = plist.Delete() - if err != nil { - klog.ErrorS(err, "Failed to delete policy list") - } - } - -} - -func getHnsNetworkInfo(hnsNetworkName string) (*hnsNetworkInfo, error) { - hnsnetwork, err := hcsshim.GetHNSNetworkByName(hnsNetworkName) - if err != nil { - klog.ErrorS(err, "Failed to get HNS Network by name") - return nil, err - } - - return &hnsNetworkInfo{ - id: hnsnetwork.Id, - name: hnsnetwork.Name, - networkType: hnsnetwork.Type, - }, nil -} - // Sync is called to synchronize the proxier state to hns as soon as possible. func (proxier *Proxier) Sync() { if proxier.healthzServer != nil { diff --git a/pkg/proxy/winkernel/proxier_test.go b/pkg/proxy/winkernel/proxier_test.go index c6b5e35d366..4864a6de7b2 100644 --- a/pkg/proxy/winkernel/proxier_test.go +++ b/pkg/proxy/winkernel/proxier_test.go @@ -20,12 +20,14 @@ limitations under the License. package winkernel import ( + "encoding/json" "fmt" "net" "strings" "testing" "time" + "github.com/Microsoft/hcsshim/hcn" v1 "k8s.io/api/core/v1" discovery "k8s.io/api/discovery/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -33,26 +35,57 @@ import ( "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/kubernetes/pkg/proxy" "k8s.io/kubernetes/pkg/proxy/healthcheck" + fakehcn "k8s.io/kubernetes/pkg/proxy/winkernel/testing" netutils "k8s.io/utils/net" "k8s.io/utils/pointer" ) const ( testHostName = "test-hostname" + testNetwork = "TestNetwork" + ipAddress = "10.0.0.1" + prefixLen = 24 macAddress = "00-11-22-33-44-55" clusterCIDR = "192.168.1.0/24" destinationPrefix = "192.168.2.0/24" providerAddress = "10.0.0.3" guid = "123ABC" + endpointGuid1 = "EPID-1" + loadbalancerGuid1 = "LBID-1" + endpointLocal = "EP-LOCAL" + endpointGw = "EP-GW" + epIpAddressGw = "192.168.2.1" + epMacAddressGw = "00-11-22-33-44-66" ) -type fakeHNS struct{} +func newHnsNetwork(networkInfo *hnsNetworkInfo) *hcn.HostComputeNetwork { + var policies []hcn.NetworkPolicy + for _, remoteSubnet := range networkInfo.remoteSubnets { + policySettings := hcn.RemoteSubnetRoutePolicySetting{ + DestinationPrefix: remoteSubnet.destinationPrefix, + IsolationId: remoteSubnet.isolationID, + ProviderAddress: remoteSubnet.providerAddress, + DistributedRouterMacAddress: remoteSubnet.drMacAddress, + } + settings, _ := json.Marshal(policySettings) + policy := hcn.NetworkPolicy{ + Type: hcn.RemoteSubnetRoute, + Settings: settings, + } + policies = append(policies, policy) + } -func newFakeHNS() *fakeHNS { - return &fakeHNS{} + network := &hcn.HostComputeNetwork{ + Id: networkInfo.id, + Name: networkInfo.name, + Type: hcn.NetworkType(networkInfo.networkType), + Policies: policies, + } + return network } -func (hns fakeHNS) getNetworkByName(name string) (*hnsNetworkInfo, error) { +func NewFakeProxier(syncPeriod time.Duration, minSyncPeriod time.Duration, clusterCIDR string, hostname string, nodeIP net.IP, networkType string) *Proxier { + sourceVip := "192.168.1.2" var remoteSubnets []*remoteSubnetInfo rs := &remoteSubnetInfo{ destinationPrefix: destinationPrefix, @@ -61,96 +94,32 @@ func (hns fakeHNS) getNetworkByName(name string) (*hnsNetworkInfo, error) { drMacAddress: macAddress, } remoteSubnets = append(remoteSubnets, rs) - return &hnsNetworkInfo{ - id: strings.ToUpper(guid), - name: name, - networkType: NETWORK_TYPE_OVERLAY, - remoteSubnets: remoteSubnets, - }, nil -} - -func (hns fakeHNS) getAllEndpointsByNetwork(networkName string) (map[string]*(endpointsInfo), error) { - return nil, nil -} - -func (hns fakeHNS) getEndpointByID(id string) (*endpointsInfo, error) { - return nil, nil -} - -func (hns fakeHNS) getEndpointByName(name string) (*endpointsInfo, error) { - return &endpointsInfo{ - isLocal: true, - macAddress: macAddress, - hnsID: guid, - hns: hns, - }, nil -} - -func (hns fakeHNS) getAllLoadBalancers() (map[loadBalancerIdentifier]*loadBalancerInfo, error) { - return nil, nil -} - -func (hns fakeHNS) getEndpointByIpAddress(ip string, networkName string) (*endpointsInfo, error) { - _, ipNet, _ := netutils.ParseCIDRSloppy(destinationPrefix) - - if ipNet.Contains(netutils.ParseIPSloppy(ip)) { - return &endpointsInfo{ - ip: ip, - isLocal: false, - macAddress: macAddress, - hnsID: guid, - hns: hns, - }, nil - } - return nil, nil - -} - -func (hns fakeHNS) createEndpoint(ep *endpointsInfo, networkName string) (*endpointsInfo, error) { - return &endpointsInfo{ - ip: ep.ip, - isLocal: ep.isLocal, - macAddress: ep.macAddress, - hnsID: guid, - hns: hns, - }, nil -} - -func (hns fakeHNS) deleteEndpoint(hnsID string) error { - return nil -} - -func (hns fakeHNS) getLoadBalancer(endpoints []endpointsInfo, flags loadBalancerFlags, sourceVip string, vip string, protocol uint16, internalPort uint16, externalPort uint16, previousLoadBalancers map[loadBalancerIdentifier]*loadBalancerInfo) (*loadBalancerInfo, error) { - return &loadBalancerInfo{ - hnsID: guid, - }, nil -} - -func (hns fakeHNS) deleteLoadBalancer(hnsID string) error { - return nil -} - -func NewFakeProxier(syncPeriod time.Duration, minSyncPeriod time.Duration, clusterCIDR string, hostname string, nodeIP net.IP, networkType string) *Proxier { - sourceVip := "192.168.1.2" hnsNetworkInfo := &hnsNetworkInfo{ - id: strings.ToUpper(guid), - name: "TestNetwork", - networkType: networkType, + id: strings.ToUpper(guid), + name: testNetwork, + networkType: networkType, + remoteSubnets: remoteSubnets, } + hnsNetwork := newHnsNetwork(hnsNetworkInfo) + hcnMock := fakehcn.NewHcnMock(hnsNetwork) proxier := &Proxier{ - svcPortMap: make(proxy.ServicePortMap), - endpointsMap: make(proxy.EndpointsMap), - clusterCIDR: clusterCIDR, - hostname: testHostName, - nodeIP: nodeIP, - serviceHealthServer: healthcheck.NewFakeServiceHealthServer(), - network: *hnsNetworkInfo, - sourceVip: sourceVip, - hostMac: macAddress, - isDSR: false, - hns: newFakeHNS(), + svcPortMap: make(proxy.ServicePortMap), + endpointsMap: make(proxy.EndpointsMap), + clusterCIDR: clusterCIDR, + hostname: testHostName, + nodeIP: nodeIP, + serviceHealthServer: healthcheck.NewFakeServiceHealthServer(), + network: *hnsNetworkInfo, + sourceVip: sourceVip, + hostMac: macAddress, + isDSR: false, + hns: &hns{ + hcn: hcnMock, + }, + hcn: hcnMock, endPointsRefCount: make(endPointsReferenceCountMap), forwardHealthCheckVip: true, + mapStaleLoadbalancers: make(map[string]bool), } serviceChanges := proxy.NewServiceChangeTracker(proxier.newServiceInfo, v1.IPv4Protocol, nil, proxier.serviceMapChange) @@ -258,6 +227,7 @@ func TestCreateRemoteEndpointOverlay(t *testing.T) { }} }), ) + proxier.setInitialized(true) proxier.syncProxyRules() @@ -267,17 +237,17 @@ func TestCreateRemoteEndpointOverlay(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } - if *proxier.endPointsRefCount[guid] <= 0 { - t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[guid]) + if *proxier.endPointsRefCount[endpointGuid1] <= 0 { + t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[endpointGuid1]) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } } @@ -331,17 +301,17 @@ func TestCreateRemoteEndpointL2Bridge(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } - if *proxier.endPointsRefCount[guid] <= 0 { - t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[guid]) + if *proxier.endPointsRefCount[endpointGuid1] <= 0 { + t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[endpointGuid1]) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } } func TestSharedRemoteEndpointDelete(t *testing.T) { @@ -424,17 +394,17 @@ func TestSharedRemoteEndpointDelete(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName1.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } - if *proxier.endPointsRefCount[guid] != 2 { - t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[guid]) + if *proxier.endPointsRefCount[endpointGuid1] != 2 { + t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[endpointGuid1]) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } proxier.setInitialized(false) @@ -474,8 +444,8 @@ func TestSharedRemoteEndpointDelete(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName1.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } @@ -483,8 +453,8 @@ func TestSharedRemoteEndpointDelete(t *testing.T) { t.Errorf("Incorrect Refcount. Current value: %v", *epInfo.refCount) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } } func TestSharedRemoteEndpointUpdate(t *testing.T) { @@ -568,17 +538,17 @@ func TestSharedRemoteEndpointUpdate(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName1.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } - if *proxier.endPointsRefCount[guid] != 2 { - t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[guid]) + if *proxier.endPointsRefCount[endpointGuid1] != 2 { + t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[endpointGuid1]) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } proxier.setInitialized(false) @@ -648,8 +618,8 @@ func TestSharedRemoteEndpointUpdate(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName1.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("%v does not match %v", epInfo.hnsID, guid) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } @@ -657,8 +627,8 @@ func TestSharedRemoteEndpointUpdate(t *testing.T) { t.Errorf("Incorrect refcount. Current value: %v", *epInfo.refCount) } - if *proxier.endPointsRefCount[guid] != *epInfo.refCount { - t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[guid], *epInfo.refCount) + if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } } func TestCreateLoadBalancer(t *testing.T) { @@ -713,8 +683,8 @@ func TestCreateLoadBalancer(t *testing.T) { t.Errorf("Failed to cast serviceInfo %q", svcPortName.String()) } else { - if svcInfo.hnsID != guid { - t.Errorf("%v does not match %v", svcInfo.hnsID, guid) + if svcInfo.hnsID != loadbalancerGuid1 { + t.Errorf("%v does not match %v", svcInfo.hnsID, loadbalancerGuid1) } } } @@ -758,6 +728,7 @@ func TestCreateDsrLoadBalancer(t *testing.T) { eps.AddressType = discovery.AddressTypeIPv4 eps.Endpoints = []discovery.Endpoint{{ Addresses: []string{epIpAddressRemote}, + NodeName: pointer.String("testhost"), }} eps.Ports = []discovery.EndpointPort{{ Name: pointer.String(svcPortName.Port), @@ -767,6 +738,10 @@ func TestCreateDsrLoadBalancer(t *testing.T) { }), ) + hcn := (proxier.hcn).(*fakehcn.HcnMock) + proxier.rootHnsEndpointName = endpointGw + hcn.PopulateQueriedEndpoints(endpointLocal, guid, epIpAddressRemote, macAddress, prefixLen) + hcn.PopulateQueriedEndpoints(endpointGw, guid, epIpAddressGw, epMacAddressGw, prefixLen) proxier.setInitialized(true) proxier.syncProxyRules() @@ -776,16 +751,16 @@ func TestCreateDsrLoadBalancer(t *testing.T) { t.Errorf("Failed to cast serviceInfo %q", svcPortName.String()) } else { - if svcInfo.hnsID != guid { - t.Errorf("%v does not match %v", svcInfo.hnsID, guid) + if svcInfo.hnsID != loadbalancerGuid1 { + t.Errorf("%v does not match %v", svcInfo.hnsID, loadbalancerGuid1) } if svcInfo.localTrafficDSR != true { t.Errorf("Failed to create DSR loadbalancer with local traffic policy") } if len(svcInfo.loadBalancerIngressIPs) == 0 { t.Errorf("svcInfo does not have any loadBalancerIngressIPs, %+v", svcInfo) - } else if svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID != guid { - t.Errorf("The Hns Loadbalancer HealthCheck Id %v does not match %v. ServicePortName %q", svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID, guid, svcPortName.String()) + } else if svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID != loadbalancerGuid1 { + t.Errorf("The Hns Loadbalancer HealthCheck Id %v does not match %v. ServicePortName %q", svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID, loadbalancerGuid1, svcPortName.String()) } } } @@ -796,6 +771,7 @@ func TestCreateDsrLoadBalancer(t *testing.T) { func TestClusterIPLBInCreateDsrLoadBalancer(t *testing.T) { syncPeriod := 30 * time.Second proxier := NewFakeProxier(syncPeriod, syncPeriod, clusterCIDR, "testhost", netutils.ParseIPSloppy("10.0.0.1"), NETWORK_TYPE_OVERLAY) + if proxier == nil { t.Error() } @@ -852,8 +828,8 @@ func TestClusterIPLBInCreateDsrLoadBalancer(t *testing.T) { } else { // Checking ClusterIP Loadbalancer is created - if svcInfo.hnsID != guid { - t.Errorf("%v does not match %v", svcInfo.hnsID, guid) + if svcInfo.hnsID != loadbalancerGuid1 { + t.Errorf("%v does not match %v", svcInfo.hnsID, loadbalancerGuid1) } // Verifying NodePort Loadbalancer is not created if svcInfo.nodePorthnsID != "" { @@ -930,8 +906,8 @@ func TestEndpointSlice(t *testing.T) { t.Errorf("Failed to cast serviceInfo %q", svcPortName.String()) } else { - if svcInfo.hnsID != guid { - t.Errorf("The Hns Loadbalancer Id %v does not match %v. ServicePortName %q", svcInfo.hnsID, guid, svcPortName.String()) + if svcInfo.hnsID != loadbalancerGuid1 { + t.Errorf("The Hns Loadbalancer Id %v does not match %v. ServicePortName %q", svcInfo.hnsID, loadbalancerGuid1, svcPortName.String()) } } @@ -941,8 +917,8 @@ func TestEndpointSlice(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName.String()) } else { - if epInfo.hnsID != guid { - t.Errorf("Hns EndpointId %v does not match %v. ServicePortName %q", epInfo.hnsID, guid, svcPortName.String()) + if epInfo.hnsID != endpointGuid1 { + t.Errorf("Hns EndpointId %v does not match %v. ServicePortName %q", epInfo.hnsID, endpointGuid1, svcPortName.String()) } } } @@ -956,7 +932,13 @@ func TestNoopEndpointSlice(t *testing.T) { } func TestFindRemoteSubnetProviderAddress(t *testing.T) { - networkInfo, _ := newFakeHNS().getNetworkByName("TestNetwork") + syncPeriod := 30 * time.Second + proxier := NewFakeProxier(syncPeriod, syncPeriod, clusterCIDR, "testhost", netutils.ParseIPSloppy("10.0.0.1"), NETWORK_TYPE_OVERLAY) + if proxier == nil { + t.Error() + } + + networkInfo, _ := proxier.hns.getNetworkByName(testNetwork) pa := networkInfo.findRemoteSubnetProviderAddress(providerAddress) if pa != providerAddress { diff --git a/pkg/proxy/winkernel/testing/hcnutils_mock.go b/pkg/proxy/winkernel/testing/hcnutils_mock.go new file mode 100644 index 00000000000..1c25ff62635 --- /dev/null +++ b/pkg/proxy/winkernel/testing/hcnutils_mock.go @@ -0,0 +1,212 @@ +//go:build windows +// +build windows + +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "errors" + "fmt" + + "github.com/Microsoft/hcsshim/hcn" +) + +type HcnMock struct { + epIdCounter int + lbIdCounter int + endpointMap map[string]*hcn.HostComputeEndpoint + loadbalancerMap map[string]*hcn.HostComputeLoadBalancer + supportedFeatures hcn.SupportedFeatures + network *hcn.HostComputeNetwork +} + +func (hcnObj HcnMock) generateEndpointGuid() (endpointId string, endpointName string) { + hcnObj.epIdCounter++ + endpointId = fmt.Sprintf("EPID-%d", hcnObj.epIdCounter) + endpointName = fmt.Sprintf("EPName-%d", hcnObj.epIdCounter) + return +} + +func (hcnObj HcnMock) generateLoadbalancerGuid() (loadbalancerId string) { + hcnObj.lbIdCounter++ + loadbalancerId = fmt.Sprintf("LBID-%d", hcnObj.lbIdCounter) + return +} + +func NewHcnMock(hnsNetwork *hcn.HostComputeNetwork) *HcnMock { + return &HcnMock{ + epIdCounter: 0, + lbIdCounter: 0, + endpointMap: make(map[string]*hcn.HostComputeEndpoint), + loadbalancerMap: make(map[string]*hcn.HostComputeLoadBalancer), + supportedFeatures: hcn.SupportedFeatures{ + Api: hcn.ApiSupport{ + V2: true, + }, + DSR: true, + IPv6DualStack: true, + }, + network: hnsNetwork, + } +} + +func (hcnObj HcnMock) PopulateQueriedEndpoints(epId, hnsId, ipAddress, mac string, prefixLen uint8) { + endpoint := &hcn.HostComputeEndpoint{ + Id: epId, + Name: epId, + HostComputeNetwork: hnsId, + IpConfigurations: []hcn.IpConfig{ + { + IpAddress: ipAddress, + PrefixLength: prefixLen, + }, + }, + MacAddress: mac, + } + + hcnObj.endpointMap[endpoint.Id] = endpoint + hcnObj.endpointMap[endpoint.Name] = endpoint +} + +func (hcnObj HcnMock) GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) { + return hcnObj.network, nil +} + +func (hcnObj HcnMock) GetNetworkByID(networkID string) (*hcn.HostComputeNetwork, error) { + return hcnObj.network, nil +} + +func (hcnObj HcnMock) ListEndpoints() ([]hcn.HostComputeEndpoint, error) { + var hcnEPList []hcn.HostComputeEndpoint + for _, ep := range hcnObj.endpointMap { + hcnEPList = append(hcnEPList, *ep) + } + return hcnEPList, nil +} + +func (hcnObj HcnMock) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) { + var hcnEPList []hcn.HostComputeEndpoint + for _, ep := range hcnObj.endpointMap { + if ep.HostComputeNetwork == networkId { + hcnEPList = append(hcnEPList, *ep) + } + } + return hcnEPList, nil +} + +func (hcnObj HcnMock) GetEndpointByID(endpointId string) (*hcn.HostComputeEndpoint, error) { + if ep, ok := hcnObj.endpointMap[endpointId]; ok { + return ep, nil + } + epNotFoundError := hcn.EndpointNotFoundError{EndpointID: endpointId} + return nil, epNotFoundError +} + +func (hcnObj HcnMock) GetEndpointByName(endpointName string) (*hcn.HostComputeEndpoint, error) { + if ep, ok := hcnObj.endpointMap[endpointName]; ok { + return ep, nil + } + epNotFoundError := hcn.EndpointNotFoundError{EndpointName: endpointName} + return nil, epNotFoundError +} + +func (hcnObj HcnMock) CreateEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { + if _, err := hcnObj.GetNetworkByID(network.Id); err != nil { + return nil, err + } + if _, ok := hcnObj.endpointMap[endpoint.Id]; ok { + return nil, fmt.Errorf("endpoint id %s already present", endpoint.Id) + } + if _, ok := hcnObj.endpointMap[endpoint.Name]; ok { + return nil, fmt.Errorf("endpoint Name %s already present", endpoint.Name) + } + endpoint.Id, endpoint.Name = hcnObj.generateEndpointGuid() + hcnObj.endpointMap[endpoint.Id] = endpoint + hcnObj.endpointMap[endpoint.Name] = endpoint + return endpoint, nil +} + +func (hcnObj HcnMock) CreateRemoteEndpoint(network *hcn.HostComputeNetwork, endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { + return hcnObj.CreateEndpoint(network, endpoint) +} + +func (hcnObj HcnMock) DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error { + if _, ok := hcnObj.endpointMap[endpoint.Id]; !ok { + return hcn.EndpointNotFoundError{EndpointID: endpoint.Id} + } + delete(hcnObj.endpointMap, endpoint.Id) + delete(hcnObj.endpointMap, endpoint.Name) + return nil +} + +func (hcnObj HcnMock) ListLoadBalancers() ([]hcn.HostComputeLoadBalancer, error) { + var hcnLBList []hcn.HostComputeLoadBalancer + for _, lb := range hcnObj.loadbalancerMap { + hcnLBList = append(hcnLBList, *lb) + } + return hcnLBList, nil +} + +func (hcnObj HcnMock) GetLoadBalancerByID(loadBalancerId string) (*hcn.HostComputeLoadBalancer, error) { + if lb, ok := hcnObj.loadbalancerMap[loadBalancerId]; ok { + return lb, nil + } + lbNotFoundError := hcn.LoadBalancerNotFoundError{LoadBalancerId: loadBalancerId} + return nil, lbNotFoundError +} + +func (hcnObj HcnMock) CreateLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) (*hcn.HostComputeLoadBalancer, error) { + if _, ok := hcnObj.loadbalancerMap[loadBalancer.Id]; ok { + return nil, fmt.Errorf("LoadBalancer id %s Already Present", loadBalancer.Id) + } + loadBalancer.Id = hcnObj.generateLoadbalancerGuid() + hcnObj.loadbalancerMap[loadBalancer.Id] = loadBalancer + return loadBalancer, nil +} + +func (hcnObj HcnMock) DeleteLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) error { + if _, ok := hcnObj.loadbalancerMap[loadBalancer.Id]; !ok { + return hcn.LoadBalancerNotFoundError{LoadBalancerId: loadBalancer.Id} + } + delete(hcnObj.loadbalancerMap, loadBalancer.Id) + return nil +} + +func (hcnObj HcnMock) GetSupportedFeatures() hcn.SupportedFeatures { + return hcnObj.supportedFeatures +} + +func (hcnObj HcnMock) Ipv6DualStackSupported() error { + if hcnObj.supportedFeatures.IPv6DualStack { + return nil + } + return errors.New("IPV6 DualStack Not Supported") +} + +func (hcnObj HcnMock) DsrSupported() error { + if hcnObj.supportedFeatures.DSR { + return nil + } + return errors.New("DSR Not Supported") +} + +func (hcnObj HcnMock) DeleteAllHnsLoadBalancerPolicy() { + for k := range hcnObj.loadbalancerMap { + delete(hcnObj.loadbalancerMap, k) + } +}