Merge pull request #81500 from feiskyer/fix-81496

Get location and subscriptionID from IMDS when useInstanceMetadata is true
This commit is contained in:
Kubernetes Prow Robot 2019-08-16 19:10:20 -07:00 committed by GitHub
commit 667ea63ec2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 97 additions and 89 deletions

View File

@ -50,8 +50,8 @@ type NetworkData struct {
// IPAddress represents IP address information.
type IPAddress struct {
PrivateIP string `json:"privateIPAddress"`
PublicIP string `json:"publicIPAddress"`
PrivateIP string `json:"privateIpAddress"`
PublicIP string `json:"publicIpAddress"`
}
// Subnet represents subnet information.
@ -62,6 +62,7 @@ type Subnet struct {
// ComputeMetadata represents compute information
type ComputeMetadata struct {
Environment string `json:"azEnvironment,omitempty"`
SKU string `json:"sku,omitempty"`
Name string `json:"name,omitempty"`
Zone string `json:"zone,omitempty"`
@ -72,6 +73,7 @@ type ComputeMetadata struct {
UpdateDomain string `json:"platformUpdateDomain,omitempty"`
ResourceGroup string `json:"resourceGroupName,omitempty"`
VMScaleSetName string `json:"vmScaleSetName,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
}
// InstanceMetadata represents instance information.
@ -111,7 +113,7 @@ func (ims *InstanceMetadataService) getInstanceMetadata(key string) (interface{}
q := req.URL.Query()
q.Add("format", "json")
q.Add("api-version", "2017-12-01")
q.Add("api-version", "2019-03-11")
req.URL.RawQuery = q.Encode()
client := &http.Client{}

View File

@ -281,12 +281,13 @@ func (az *Cloud) InstanceID(ctx context.Context, name types.NodeName) (string, e
return "", fmt.Errorf("no credentials provided for Azure cloud provider")
}
// Get resource group name.
// Get resource group name and subscription ID.
resourceGroup := strings.ToLower(metadata.Compute.ResourceGroup)
subscriptionID := strings.ToLower(metadata.Compute.SubscriptionID)
// Compose instanceID based on nodeName for standard instance.
if az.VMType == vmTypeStandard {
return az.getStandardMachineID(resourceGroup, nodeName), nil
if metadata.Compute.VMScaleSetName == "" {
return az.getStandardMachineID(subscriptionID, resourceGroup, nodeName), nil
}
// Get scale set name and instanceID from vmName for vmss.
@ -294,12 +295,12 @@ func (az *Cloud) InstanceID(ctx context.Context, name types.NodeName) (string, e
if err != nil {
if err == ErrorNotVmssInstance {
// Compose machineID for standard Node.
return az.getStandardMachineID(resourceGroup, nodeName), nil
return az.getStandardMachineID(subscriptionID, resourceGroup, nodeName), nil
}
return "", err
}
// Compose instanceID based on ssName and instanceID for vmss instance.
return az.getVmssMachineID(resourceGroup, ssName, instanceID), nil
return az.getVmssMachineID(subscriptionID, resourceGroup, ssName, instanceID), nil
}
return az.vmSet.GetInstanceIDByNodeName(nodeName)

View File

@ -81,6 +81,7 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull
func TestInstanceID(t *testing.T) {
cloud := getTestCloud()
cloud.Config.UseInstanceMetadata = true
testcases := []struct {
name string
@ -120,7 +121,7 @@ func TestInstanceID(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"name":"%s"}}`, test.metadataName))
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"name":"%s","subscriptionId":"subscription","resourceGroupName":"rg"}}`, test.metadataName))
}))
go func() {
http.Serve(listener, mux)
@ -214,7 +215,7 @@ func TestInstanceShutdownByProviderID(t *testing.T) {
for _, test := range testcases {
cloud := getTestCloud()
setTestVirtualMachines(cloud, test.vmList, false)
providerID := "azure://" + cloud.getStandardMachineID("rg", test.nodeName)
providerID := "azure://" + cloud.getStandardMachineID("subscription", "rg", test.nodeName)
hasShutdown, err := cloud.InstanceShutdownByProviderID(context.Background(), providerID)
if test.expectError {
if err == nil {

View File

@ -314,7 +314,7 @@ func (c *Cloud) GetAzureDiskLabels(diskURI string) (map[string]string, error) {
return nil, fmt.Errorf("failed to parse zone %v for AzureDisk %v: %v", zones, diskName, err)
}
zone := c.makeZone(zoneID)
zone := c.makeZone(c.Location, zoneID)
klog.V(4).Infof("Got zone %q for Azure disk %q", zone, diskName)
labels := map[string]string{
v1.LabelZoneRegion: c.Location,

View File

@ -66,10 +66,10 @@ var nicResourceGroupRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGro
var publicIPResourceGroupRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGroups/(.+)/providers/Microsoft.Network/publicIPAddresses/(?:.*)`)
// getStandardMachineID returns the full identifier of a virtual machine.
func (az *Cloud) getStandardMachineID(resourceGroup, machineName string) string {
func (az *Cloud) getStandardMachineID(subscriptionID, resourceGroup, machineName string) string {
return fmt.Sprintf(
machineIDTemplate,
az.SubscriptionID,
subscriptionID,
strings.ToLower(resourceGroup),
machineName)
}
@ -413,7 +413,7 @@ func (as *availabilitySet) GetZoneByNodeName(name string) (cloudprovider.Zone, e
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %v", zones, err)
}
failureDomain = as.makeZone(zoneID)
failureDomain = as.makeZone(to.String(vm.Location), zoneID)
} else {
// Availability zone is not used for the node, falling back to fault domain.
failureDomain = strconv.Itoa(int(*vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain))
@ -421,7 +421,7 @@ func (as *availabilitySet) GetZoneByNodeName(name string) (cloudprovider.Zone, e
zone := cloudprovider.Zone{
FailureDomain: failureDomain,
Region: *(vm.Location),
Region: to.String(vm.Location),
}
return zone, nil
}

View File

@ -21,8 +21,6 @@ import (
"context"
"fmt"
"math"
"net"
"net/http"
"strings"
"testing"
@ -1726,70 +1724,6 @@ func validateEmptyConfig(t *testing.T, config string) {
}
}
func TestGetZone(t *testing.T) {
cloud := &Cloud{
Config: Config{
Location: "eastus",
UseInstanceMetadata: true,
},
}
testcases := []struct {
name string
zone string
faultDomain string
expected string
}{
{
name: "GetZone should get real zone if only node's zone is set",
zone: "1",
expected: "eastus-1",
},
{
name: "GetZone should get real zone if both node's zone and FD are set",
zone: "1",
faultDomain: "99",
expected: "eastus-1",
},
{
name: "GetZone should get faultDomain if node's zone isn't set",
faultDomain: "99",
expected: "99",
},
}
for _, test := range testcases {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"zone":"%s", "platformFaultDomain":"%s"}}`, test.zone, test.faultDomain))
}))
go func() {
http.Serve(listener, mux)
}()
defer listener.Close()
cloud.metadata, err = NewInstanceMetadataService("http://" + listener.Addr().String() + "/")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
zone, err := cloud.GetZone(context.Background())
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
if zone.FailureDomain != test.expected {
t.Errorf("Test [%s] unexpected zone: %s, expected %q", test.name, zone.FailureDomain, test.expected)
}
if zone.Region != cloud.Location {
t.Errorf("Test [%s] unexpected region: %s, expected: %s", test.name, zone.Region, cloud.Location)
}
}
}
func TestGetNodeNameByProviderID(t *testing.T) {
az := getTestCloud()
providers := []struct {

View File

@ -280,7 +280,7 @@ func (ss *scaleSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) {
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %v", zones, err)
}
failureDomain = ss.makeZone(zoneID)
failureDomain = ss.makeZone(to.String(vm.Location), zoneID)
} else if vm.InstanceView != nil && vm.InstanceView.PlatformFaultDomain != nil {
// Availability zone is not used for the node, falling back to fault domain.
failureDomain = strconv.Itoa(int(*vm.InstanceView.PlatformFaultDomain))
@ -288,7 +288,7 @@ func (ss *scaleSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) {
return cloudprovider.Zone{
FailureDomain: failureDomain,
Region: *vm.Location,
Region: to.String(vm.Location),
}, nil
}
@ -399,10 +399,10 @@ func (ss *scaleSet) getPrimaryInterfaceID(machine compute.VirtualMachineScaleSet
}
// getVmssMachineID returns the full identifier of a vmss virtual machine.
func (az *Cloud) getVmssMachineID(resourceGroup, scaleSetName, instanceID string) string {
func (az *Cloud) getVmssMachineID(subscriptionID, resourceGroup, scaleSetName, instanceID string) string {
return fmt.Sprintf(
vmssMachineIDTemplate,
az.SubscriptionID,
subscriptionID,
strings.ToLower(resourceGroup),
scaleSetName,
instanceID)

View File

@ -29,8 +29,8 @@ import (
)
// makeZone returns the zone value in format of <region>-<zone-id>.
func (az *Cloud) makeZone(zoneID int) string {
return fmt.Sprintf("%s-%d", strings.ToLower(az.Location), zoneID)
func (az *Cloud) makeZone(location string, zoneID int) string {
return fmt.Sprintf("%s-%d", strings.ToLower(location), zoneID)
}
// isAvailabilityZone returns true if the zone is in format of <region>-<zone-id>.
@ -57,16 +57,18 @@ func (az *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) {
}
if metadata.Compute == nil {
az.metadata.imsCache.Delete(metadataCacheKey)
return cloudprovider.Zone{}, fmt.Errorf("failure of getting compute information from instance metadata")
}
zone := ""
location := metadata.Compute.Location
if metadata.Compute.Zone != "" {
zoneID, err := strconv.Atoi(metadata.Compute.Zone)
if err != nil {
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone ID %q: %v", metadata.Compute.Zone, err)
}
zone = az.makeZone(zoneID)
zone = az.makeZone(location, zoneID)
} else {
klog.V(3).Infof("Availability zone is not enabled for the node, falling back to fault domain")
zone = metadata.Compute.FaultDomain
@ -74,7 +76,7 @@ func (az *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) {
return cloudprovider.Zone{
FailureDomain: zone,
Region: az.Location,
Region: location,
}, nil
}
// if UseInstanceMetadata is false, get Zone name by calling ARM

View File

@ -17,6 +17,10 @@ limitations under the License.
package azure
import (
"context"
"fmt"
"net"
"net/http"
"testing"
)
@ -71,3 +75,67 @@ func TestGetZoneID(t *testing.T) {
}
}
}
func TestGetZone(t *testing.T) {
cloud := &Cloud{
Config: Config{
Location: "eastus",
UseInstanceMetadata: true,
},
}
testcases := []struct {
name string
zone string
faultDomain string
expected string
}{
{
name: "GetZone should get real zone if only node's zone is set",
zone: "1",
expected: "eastus-1",
},
{
name: "GetZone should get real zone if both node's zone and FD are set",
zone: "1",
faultDomain: "99",
expected: "eastus-1",
},
{
name: "GetZone should get faultDomain if node's zone isn't set",
faultDomain: "99",
expected: "99",
},
}
for _, test := range testcases {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"zone":"%s", "platformFaultDomain":"%s", "location":"eastus"}}`, test.zone, test.faultDomain))
}))
go func() {
http.Serve(listener, mux)
}()
defer listener.Close()
cloud.metadata, err = NewInstanceMetadataService("http://" + listener.Addr().String() + "/")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
zone, err := cloud.GetZone(context.Background())
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
if zone.FailureDomain != test.expected {
t.Errorf("Test [%s] unexpected zone: %s, expected %q", test.name, zone.FailureDomain, test.expected)
}
if zone.Region != cloud.Location {
t.Errorf("Test [%s] unexpected region: %s, expected: %s", test.name, zone.Region, cloud.Location)
}
}
}