diff --git a/pkg/cloudprovider/providers/azure/azure_vmss_test.go b/pkg/cloudprovider/providers/azure/azure_vmss_test.go index e60322a9e22..cdee4aacaca 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss_test.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss_test.go @@ -24,9 +24,9 @@ import ( "github.com/stretchr/testify/assert" ) -func newTestScaleSet(scaleSetName string, vmList []string) (*scaleSet, error) { +func newTestScaleSet(scaleSetName, zone string, faultDomain int32, vmList []string) (*scaleSet, error) { cloud := getTestCloud() - setTestVirtualMachineCloud(cloud, scaleSetName, vmList) + setTestVirtualMachineCloud(cloud, scaleSetName, zone, faultDomain, vmList) ss, err := newScaleSet(cloud) if err != nil { return nil, err @@ -35,7 +35,7 @@ func newTestScaleSet(scaleSetName string, vmList []string) (*scaleSet, error) { return ss.(*scaleSet), nil } -func setTestVirtualMachineCloud(ss *Cloud, scaleSetName string, vmList []string) { +func setTestVirtualMachineCloud(ss *Cloud, scaleSetName, zone string, faultDomain int32, vmList []string) { virtualMachineScaleSetsClient := newFakeVirtualMachineScaleSetsClient() scaleSets := make(map[string]map[string]compute.VirtualMachineScaleSet) scaleSets["rg"] = map[string]compute.VirtualMachineScaleSet{ @@ -58,7 +58,7 @@ func setTestVirtualMachineCloud(ss *Cloud, scaleSetName string, vmList []string) ID: &nodeName, }, } - ssVMs["rg"][vmName] = compute.VirtualMachineScaleSetVM{ + vmssVM := compute.VirtualMachineScaleSetVM{ VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ OsProfile: &compute.OSProfile{ ComputerName: &nodeName, @@ -66,12 +66,21 @@ func setTestVirtualMachineCloud(ss *Cloud, scaleSetName string, vmList []string) NetworkProfile: &compute.NetworkProfile{ NetworkInterfaces: &networkInterfaces, }, + InstanceView: &compute.VirtualMachineScaleSetVMInstanceView{ + PlatformFaultDomain: &faultDomain, + }, }, ID: &ID, InstanceID: &instanceID, Name: &vmName, Location: &ss.Location, } + + if zone != "" { + zones := []string{zone} + vmssVM.Zones = &zones + } + ssVMs["rg"][vmName] = vmssVM } virtualMachineScaleSetVMsClient.setFakeStore(ssVMs) @@ -141,7 +150,7 @@ func TestGetInstanceIDByNodeName(t *testing.T) { } for _, test := range testCases { - ss, err := newTestScaleSet(test.scaleSet, test.vmList) + ss, err := newTestScaleSet(test.scaleSet, "", 0, test.vmList) assert.NoError(t, err, test.description) real, err := ss.GetInstanceIDByNodeName(test.nodeName) @@ -154,3 +163,56 @@ func TestGetInstanceIDByNodeName(t *testing.T) { assert.Equal(t, test.expected, real, test.description) } } + +func TestGetZoneByNodeName(t *testing.T) { + testCases := []struct { + description string + scaleSet string + vmList []string + nodeName string + zone string + faultDomain int32 + expected string + expectError bool + }{ + { + description: "scaleSet should get faultDomain for non-zoned nodes", + scaleSet: "ss", + vmList: []string{"vmssee6c2000000", "vmssee6c2000001"}, + nodeName: "vmssee6c2000000", + faultDomain: 3, + expected: "3", + }, + { + description: "scaleSet should get availability zone for zoned nodes", + scaleSet: "ss", + vmList: []string{"vmssee6c2000000", "vmssee6c2000001"}, + nodeName: "vmssee6c2000000", + zone: "2", + faultDomain: 3, + expected: "westus-2", + }, + { + description: "scaleSet should return error for non-exist nodes", + scaleSet: "ss", + faultDomain: 3, + vmList: []string{"vmssee6c2000000", "vmssee6c2000001"}, + nodeName: "agente6c2000005", + expectError: true, + }, + } + + for _, test := range testCases { + ss, err := newTestScaleSet(test.scaleSet, test.zone, test.faultDomain, test.vmList) + assert.NoError(t, err, test.description) + + real, err := ss.GetZoneByNodeName(test.nodeName) + if test.expectError { + assert.Error(t, err, test.description) + continue + } + + assert.NoError(t, err, test.description) + assert.Equal(t, test.expected, real.FailureDomain, test.description) + } +}