diff --git a/pkg/proxy/winkernel/proxier_test.go b/pkg/proxy/winkernel/proxier_test.go index 4864a6de7b2..b11a02cec12 100644 --- a/pkg/proxy/winkernel/proxier_test.go +++ b/pkg/proxy/winkernel/proxier_test.go @@ -237,16 +237,16 @@ func TestCreateRemoteEndpointOverlay(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName.String()) } else { - if epInfo.hnsID != endpointGuid1 { + if epInfo.hnsID != "EPID-3" { t.Errorf("%v does not match %v", epInfo.hnsID, endpointGuid1) } } - if *proxier.endPointsRefCount[endpointGuid1] <= 0 { + if *proxier.endPointsRefCount["EPID-3"] <= 0 { t.Errorf("RefCount not incremented. Current value: %v", *proxier.endPointsRefCount[endpointGuid1]) } - if *proxier.endPointsRefCount[endpointGuid1] != *epInfo.refCount { + if *proxier.endPointsRefCount["EPID-3"] != *epInfo.refCount { t.Errorf("Global refCount: %v does not match endpoint refCount: %v", *proxier.endPointsRefCount[endpointGuid1], *epInfo.refCount) } } @@ -530,6 +530,7 @@ func TestSharedRemoteEndpointUpdate(t *testing.T) { }} }), ) + proxier.setInitialized(true) proxier.syncProxyRules() ep := proxier.endpointsMap[svcPortName1][0] @@ -759,7 +760,7 @@ func TestCreateDsrLoadBalancer(t *testing.T) { } if len(svcInfo.loadBalancerIngressIPs) == 0 { t.Errorf("svcInfo does not have any loadBalancerIngressIPs, %+v", svcInfo) - } else if svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID != loadbalancerGuid1 { + } else if svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID != "LBID-4" { t.Errorf("The Hns Loadbalancer HealthCheck Id %v does not match %v. ServicePortName %q", svcInfo.loadBalancerIngressIPs[0].healthCheckHnsID, loadbalancerGuid1, svcPortName.String()) } } @@ -917,7 +918,7 @@ func TestEndpointSlice(t *testing.T) { t.Errorf("Failed to cast endpointsInfo %q", svcPortName.String()) } else { - if epInfo.hnsID != endpointGuid1 { + if epInfo.hnsID != "EPID-3" { t.Errorf("Hns EndpointId %v does not match %v. ServicePortName %q", epInfo.hnsID, endpointGuid1, svcPortName.String()) } } diff --git a/pkg/proxy/winkernel/testing/hcnutils_mock.go b/pkg/proxy/winkernel/testing/hcnutils_mock.go index 1c25ff62635..319f2e10c47 100644 --- a/pkg/proxy/winkernel/testing/hcnutils_mock.go +++ b/pkg/proxy/winkernel/testing/hcnutils_mock.go @@ -26,34 +26,37 @@ import ( "github.com/Microsoft/hcsshim/hcn" ) +var ( + epIdCounter int + lbIdCounter int + endpointMap map[string]*hcn.HostComputeEndpoint + loadbalancerMap map[string]*hcn.HostComputeLoadBalancer +) + type HcnMock struct { - epIdCounter int - lbIdCounter int - endpointMap map[string]*hcn.HostComputeEndpoint - loadbalancerMap map[string]*hcn.HostComputeLoadBalancer supportedFeatures hcn.SupportedFeatures network *hcn.HostComputeNetwork } func (hcnObj HcnMock) generateEndpointGuid() (endpointId string, endpointName string) { - hcnObj.epIdCounter++ - endpointId = fmt.Sprintf("EPID-%d", hcnObj.epIdCounter) - endpointName = fmt.Sprintf("EPName-%d", hcnObj.epIdCounter) + epIdCounter++ + endpointId = fmt.Sprintf("EPID-%d", epIdCounter) + endpointName = fmt.Sprintf("EPName-%d", epIdCounter) return } func (hcnObj HcnMock) generateLoadbalancerGuid() (loadbalancerId string) { - hcnObj.lbIdCounter++ - loadbalancerId = fmt.Sprintf("LBID-%d", hcnObj.lbIdCounter) + lbIdCounter++ + loadbalancerId = fmt.Sprintf("LBID-%d", lbIdCounter) return } func NewHcnMock(hnsNetwork *hcn.HostComputeNetwork) *HcnMock { + epIdCounter = 0 + lbIdCounter = 0 + endpointMap = make(map[string]*hcn.HostComputeEndpoint) + loadbalancerMap = make(map[string]*hcn.HostComputeLoadBalancer) return &HcnMock{ - epIdCounter: 0, - lbIdCounter: 0, - endpointMap: make(map[string]*hcn.HostComputeEndpoint), - loadbalancerMap: make(map[string]*hcn.HostComputeLoadBalancer), supportedFeatures: hcn.SupportedFeatures{ Api: hcn.ApiSupport{ V2: true, @@ -79,8 +82,8 @@ func (hcnObj HcnMock) PopulateQueriedEndpoints(epId, hnsId, ipAddress, mac strin MacAddress: mac, } - hcnObj.endpointMap[endpoint.Id] = endpoint - hcnObj.endpointMap[endpoint.Name] = endpoint + endpointMap[endpoint.Id] = endpoint + endpointMap[endpoint.Name] = endpoint } func (hcnObj HcnMock) GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) { @@ -93,7 +96,7 @@ func (hcnObj HcnMock) GetNetworkByID(networkID string) (*hcn.HostComputeNetwork, func (hcnObj HcnMock) ListEndpoints() ([]hcn.HostComputeEndpoint, error) { var hcnEPList []hcn.HostComputeEndpoint - for _, ep := range hcnObj.endpointMap { + for _, ep := range endpointMap { hcnEPList = append(hcnEPList, *ep) } return hcnEPList, nil @@ -101,7 +104,7 @@ func (hcnObj HcnMock) ListEndpoints() ([]hcn.HostComputeEndpoint, error) { func (hcnObj HcnMock) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) { var hcnEPList []hcn.HostComputeEndpoint - for _, ep := range hcnObj.endpointMap { + for _, ep := range endpointMap { if ep.HostComputeNetwork == networkId { hcnEPList = append(hcnEPList, *ep) } @@ -110,7 +113,7 @@ func (hcnObj HcnMock) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComput } func (hcnObj HcnMock) GetEndpointByID(endpointId string) (*hcn.HostComputeEndpoint, error) { - if ep, ok := hcnObj.endpointMap[endpointId]; ok { + if ep, ok := endpointMap[endpointId]; ok { return ep, nil } epNotFoundError := hcn.EndpointNotFoundError{EndpointID: endpointId} @@ -118,7 +121,7 @@ func (hcnObj HcnMock) GetEndpointByID(endpointId string) (*hcn.HostComputeEndpoi } func (hcnObj HcnMock) GetEndpointByName(endpointName string) (*hcn.HostComputeEndpoint, error) { - if ep, ok := hcnObj.endpointMap[endpointName]; ok { + if ep, ok := endpointMap[endpointName]; ok { return ep, nil } epNotFoundError := hcn.EndpointNotFoundError{EndpointName: endpointName} @@ -129,15 +132,16 @@ func (hcnObj HcnMock) CreateEndpoint(network *hcn.HostComputeNetwork, endpoint * if _, err := hcnObj.GetNetworkByID(network.Id); err != nil { return nil, err } - if _, ok := hcnObj.endpointMap[endpoint.Id]; ok { + if _, ok := endpointMap[endpoint.Id]; ok { return nil, fmt.Errorf("endpoint id %s already present", endpoint.Id) } - if _, ok := hcnObj.endpointMap[endpoint.Name]; ok { + if _, ok := endpointMap[endpoint.Name]; ok { return nil, fmt.Errorf("endpoint Name %s already present", endpoint.Name) } endpoint.Id, endpoint.Name = hcnObj.generateEndpointGuid() - hcnObj.endpointMap[endpoint.Id] = endpoint - hcnObj.endpointMap[endpoint.Name] = endpoint + endpoint.HostComputeNetwork = network.Id + endpointMap[endpoint.Id] = endpoint + endpointMap[endpoint.Name] = endpoint return endpoint, nil } @@ -146,24 +150,24 @@ func (hcnObj HcnMock) CreateRemoteEndpoint(network *hcn.HostComputeNetwork, endp } func (hcnObj HcnMock) DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error { - if _, ok := hcnObj.endpointMap[endpoint.Id]; !ok { + if _, ok := endpointMap[endpoint.Id]; !ok { return hcn.EndpointNotFoundError{EndpointID: endpoint.Id} } - delete(hcnObj.endpointMap, endpoint.Id) - delete(hcnObj.endpointMap, endpoint.Name) + delete(endpointMap, endpoint.Id) + delete(endpointMap, endpoint.Name) return nil } func (hcnObj HcnMock) ListLoadBalancers() ([]hcn.HostComputeLoadBalancer, error) { var hcnLBList []hcn.HostComputeLoadBalancer - for _, lb := range hcnObj.loadbalancerMap { + for _, lb := range loadbalancerMap { hcnLBList = append(hcnLBList, *lb) } return hcnLBList, nil } func (hcnObj HcnMock) GetLoadBalancerByID(loadBalancerId string) (*hcn.HostComputeLoadBalancer, error) { - if lb, ok := hcnObj.loadbalancerMap[loadBalancerId]; ok { + if lb, ok := loadbalancerMap[loadBalancerId]; ok { return lb, nil } lbNotFoundError := hcn.LoadBalancerNotFoundError{LoadBalancerId: loadBalancerId} @@ -171,19 +175,19 @@ func (hcnObj HcnMock) GetLoadBalancerByID(loadBalancerId string) (*hcn.HostCompu } func (hcnObj HcnMock) CreateLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) (*hcn.HostComputeLoadBalancer, error) { - if _, ok := hcnObj.loadbalancerMap[loadBalancer.Id]; ok { + if _, ok := loadbalancerMap[loadBalancer.Id]; ok { return nil, fmt.Errorf("LoadBalancer id %s Already Present", loadBalancer.Id) } loadBalancer.Id = hcnObj.generateLoadbalancerGuid() - hcnObj.loadbalancerMap[loadBalancer.Id] = loadBalancer + loadbalancerMap[loadBalancer.Id] = loadBalancer return loadBalancer, nil } func (hcnObj HcnMock) DeleteLoadBalancer(loadBalancer *hcn.HostComputeLoadBalancer) error { - if _, ok := hcnObj.loadbalancerMap[loadBalancer.Id]; !ok { + if _, ok := loadbalancerMap[loadBalancer.Id]; !ok { return hcn.LoadBalancerNotFoundError{LoadBalancerId: loadBalancer.Id} } - delete(hcnObj.loadbalancerMap, loadBalancer.Id) + delete(loadbalancerMap, loadBalancer.Id) return nil } @@ -206,7 +210,7 @@ func (hcnObj HcnMock) DsrSupported() error { } func (hcnObj HcnMock) DeleteAllHnsLoadBalancerPolicy() { - for k := range hcnObj.loadbalancerMap { - delete(hcnObj.loadbalancerMap, k) + for k := range loadbalancerMap { + delete(loadbalancerMap, k) } }