diff --git a/pkg/cloudprovider/providers/azure/azure_instances.go b/pkg/cloudprovider/providers/azure/azure_instances.go index 1b85334e1aa..bc9685690d7 100644 --- a/pkg/cloudprovider/providers/azure/azure_instances.go +++ b/pkg/cloudprovider/providers/azure/azure_instances.go @@ -19,6 +19,8 @@ package azure import ( "context" "fmt" + "os" + "strings" "k8s.io/api/core/v1" "k8s.io/kubernetes/pkg/cloudprovider" @@ -100,24 +102,57 @@ func (az *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID stri func (az *Cloud) isCurrentInstance(name types.NodeName) (bool, error) { nodeName := mapNodeNameToVMName(name) metadataName, err := az.metadata.Text("instance/compute/name") + if err != nil { + return false, err + } + + if az.VMType == vmTypeVMSS { + // VMSS vmName is not same with hostname, use hostname instead. + metadataName, err = os.Hostname() + if err != nil { + return false, err + } + } + + metadataName = strings.ToLower(metadataName) return (metadataName == nodeName), err } // InstanceID returns the cloud provider ID of the specified instance. // Note that if the instance does not exist or is no longer running, we must return ("", cloudprovider.InstanceNotFound) func (az *Cloud) InstanceID(ctx context.Context, name types.NodeName) (string, error) { + nodeName := mapNodeNameToVMName(name) + if az.UseInstanceMetadata { isLocalInstance, err := az.isCurrentInstance(name) if err != nil { return "", err } - if isLocalInstance { - nodeName := mapNodeNameToVMName(name) - return az.getMachineID(nodeName), nil + + // Not local instance, get instanceID from Azure ARM API. + if !isLocalInstance { + return az.vmSet.GetInstanceIDByNodeName(nodeName) } + + // Compose instanceID based on nodeName for standard instance. + if az.VMType == vmTypeStandard { + return az.getStandardMachineID(nodeName), nil + } + + // Get scale set name and instanceID from vmName for vmss. + metadataName, err := az.metadata.Text("instance/compute/name") + if err != nil { + return "", err + } + ssName, instanceID, err := extractVmssVMName(metadataName) + if err != nil { + return "", err + } + // Compose instanceID based on ssName and instanceID for vmss instance. + return az.getVmssMachineID(ssName, instanceID), nil } - return az.vmSet.GetInstanceIDByNodeName(string(name)) + return az.vmSet.GetInstanceIDByNodeName(nodeName) } // InstanceTypeByProviderID returns the cloudprovider instance type of the node with the specified unique providerID diff --git a/pkg/cloudprovider/providers/azure/azure_standard.go b/pkg/cloudprovider/providers/azure/azure_standard.go index 27af13657b9..bb40eadfae2 100644 --- a/pkg/cloudprovider/providers/azure/azure_standard.go +++ b/pkg/cloudprovider/providers/azure/azure_standard.go @@ -60,8 +60,8 @@ const ( var errNotInVMSet = errors.New("vm is not in the vmset") var providerIDRE = regexp.MustCompile(`^` + CloudProviderName + `://(?:.*)/Microsoft.Compute/virtualMachines/(.+)$`) -// returns the full identifier of a machine -func (az *Cloud) getMachineID(machineName string) string { +// getStandardMachineID returns the full identifier of a virtual machine. +func (az *Cloud) getStandardMachineID(machineName string) string { return fmt.Sprintf( machineIDTemplate, az.SubscriptionID, diff --git a/pkg/cloudprovider/providers/azure/azure_vmss.go b/pkg/cloudprovider/providers/azure/azure_vmss.go index 14dce8391e7..8f41cedbb36 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss.go @@ -39,7 +39,8 @@ var ( // ErrorNotVmssInstance indicates an instance is not belongint to any vmss. ErrorNotVmssInstance = errors.New("not a vmss instance") - scaleSetNameRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/Microsoft.Compute/virtualMachineScaleSets/(.+)/virtualMachines(?:.*)`) + scaleSetNameRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/Microsoft.Compute/virtualMachineScaleSets/(.+)/virtualMachines(?:.*)`) + vmssMachineIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachineScaleSets/%s/virtualMachines/%s" ) // scaleSet implements VMSet interface for Azure scale set. @@ -735,3 +736,13 @@ func (ss *scaleSet) EnsureBackendPoolDeleted(poolID, vmSetName string) error { return nil } + +// getVmssMachineID returns the full identifier of a vmss virtual machine. +func (az *Cloud) getVmssMachineID(scaleSetName, instanceID string) string { + return fmt.Sprintf( + vmssMachineIDTemplate, + az.SubscriptionID, + az.ResourceGroup, + scaleSetName, + instanceID) +} diff --git a/pkg/cloudprovider/providers/azure/azure_vmss_cache.go b/pkg/cloudprovider/providers/azure/azure_vmss_cache.go index 8aad8425e61..b7f1552d7b7 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss_cache.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss_cache.go @@ -46,7 +46,7 @@ func (ss *scaleSet) makeVmssVMName(scaleSetName, instanceID string) string { return fmt.Sprintf("%s%s%s", scaleSetName, vmssNameSeparator, instanceID) } -func (ss *scaleSet) extractVmssVMName(name string) (string, string, error) { +func extractVmssVMName(name string) (string, string, error) { ret := strings.Split(name, vmssNameSeparator) if len(ret) != 2 { glog.Errorf("Failed to extract vmssVMName %q", name) @@ -128,7 +128,7 @@ func (ss *scaleSet) newAvailabilitySetNodesCache() (*timedCache, error) { func (ss *scaleSet) newVmssVMCache() (*timedCache, error) { getter := func(key string) (interface{}, error) { // vmssVM name's format is 'scaleSetName_instanceID' - ssName, instanceID, err := ss.extractVmssVMName(key) + ssName, instanceID, err := extractVmssVMName(key) if err != nil { return nil, err } diff --git a/pkg/cloudprovider/providers/azure/azure_vmss_cache_test.go b/pkg/cloudprovider/providers/azure/azure_vmss_cache_test.go index 284a26d340d..ad8bc4798a2 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss_cache_test.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss_cache_test.go @@ -23,7 +23,6 @@ import ( ) func TestExtractVmssVMName(t *testing.T) { - ss := &scaleSet{} cases := []struct { description string vmName string @@ -50,7 +49,7 @@ func TestExtractVmssVMName(t *testing.T) { } for _, c := range cases { - ssName, instanceID, err := ss.extractVmssVMName(c.vmName) + ssName, instanceID, err := extractVmssVMName(c.vmName) if c.expectError { assert.Error(t, err, c.description) continue