Merge pull request #59716 from feiskyer/vmss-disk

Automatic merge from submit-queue (batch tested with PRs 59489, 59716). If you want to cherry-pick this change to another branch, please follow the instructions <a href="https://github.com/kubernetes/community/blob/master/contributors/devel/cherry-picks.md">here</a>.

Add AzureDisk support for vmss nodes

**What this PR does / why we need it**:

This PR adds AzureDisk support for vmss nodes. Changes include

- Upgrade vmss API to 2017-12-01
- Upgrade vmss clients with new version API
- Abstract AzureDisk operations for vmss and vmas
- Added AzureDisk support for vmss
- Unit tests and fake clients fix

**Which issue(s) this PR fixes** *(optional, in `fixes #<issue number>(, fixes #<issue_number>, ...)` format, will close the issue(s) when PR gets merged)*:
Fixes #43287

**Special notes for your reviewer**:

~~Depending on #59652 (the first two commits are from #59652).~~

**Release note**:

```release-note
Add AzureDisk support for vmss nodes
```

Kubernetes-commit: d89e64110aa47d557a4b133c40b38de1b41ef7f7
This commit is contained in:
Kubernetes Publisher 2018-02-14 00:14:34 -08:00
commit 35d357565b
14 changed files with 785 additions and 128 deletions

8
Godeps/Godeps.json generated
View File

@ -16,19 +16,19 @@
}, },
{ {
"ImportPath": "github.com/Azure/go-autorest/autorest", "ImportPath": "github.com/Azure/go-autorest/autorest",
"Rev": "e14a70c556c8e0db173358d1a903dca345a8e75e" "Rev": "d4e6b95c12a08b4de2d48b45d5b4d594e5d32fab"
}, },
{ {
"ImportPath": "github.com/Azure/go-autorest/autorest/adal", "ImportPath": "github.com/Azure/go-autorest/autorest/adal",
"Rev": "e14a70c556c8e0db173358d1a903dca345a8e75e" "Rev": "d4e6b95c12a08b4de2d48b45d5b4d594e5d32fab"
}, },
{ {
"ImportPath": "github.com/Azure/go-autorest/autorest/azure", "ImportPath": "github.com/Azure/go-autorest/autorest/azure",
"Rev": "e14a70c556c8e0db173358d1a903dca345a8e75e" "Rev": "d4e6b95c12a08b4de2d48b45d5b4d594e5d32fab"
}, },
{ {
"ImportPath": "github.com/Azure/go-autorest/autorest/date", "ImportPath": "github.com/Azure/go-autorest/autorest/date",
"Rev": "e14a70c556c8e0db173358d1a903dca345a8e75e" "Rev": "d4e6b95c12a08b4de2d48b45d5b4d594e5d32fab"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/http", "ImportPath": "github.com/coreos/go-oidc/http",

View File

@ -218,6 +218,40 @@ if (err == nil) {
} }
``` ```
#### Username password authenticate
```Go
spt, err := adal.NewServicePrincipalTokenFromUsernamePassword(
oauthConfig,
applicationID,
username,
password,
resource,
callbacks...)
if (err == nil) {
token := spt.Token
}
```
#### Authorization code authenticate
``` Go
spt, err := adal.NewServicePrincipalTokenFromAuthorizationCode(
oauthConfig,
applicationID,
clientSecret,
authorizationCode,
redirectURI,
resource,
callbacks...)
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
### Command Line Tool ### Command Line Tool
A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above. A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above.

View File

@ -32,8 +32,24 @@ type OAuthConfig struct {
DeviceCodeEndpoint url.URL DeviceCodeEndpoint url.URL
} }
// IsZero returns true if the OAuthConfig object is zero-initialized.
func (oac OAuthConfig) IsZero() bool {
return oac == OAuthConfig{}
}
func validateStringParam(param, name string) error {
if len(param) == 0 {
return fmt.Errorf("parameter '" + name + "' cannot be empty")
}
return nil
}
// NewOAuthConfig returns an OAuthConfig with tenant specific urls // NewOAuthConfig returns an OAuthConfig with tenant specific urls
func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) { func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) {
if err := validateStringParam(activeDirectoryEndpoint, "activeDirectoryEndpoint"); err != nil {
return nil, err
}
// it's legal for tenantID to be empty so don't validate it
const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s" const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s"
u, err := url.Parse(activeDirectoryEndpoint) u, err := url.Parse(activeDirectoryEndpoint)
if err != nil { if err != nil {

View File

@ -27,6 +27,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/date"
@ -42,9 +43,15 @@ const (
// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows // OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
OAuthGrantTypeClientCredentials = "client_credentials" OAuthGrantTypeClientCredentials = "client_credentials"
// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
OAuthGrantTypeUserPass = "password"
// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows // OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
OAuthGrantTypeRefreshToken = "refresh_token" OAuthGrantTypeRefreshToken = "refresh_token"
// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
OAuthGrantTypeAuthorizationCode = "authorization_code"
// metadataHeader is the header required by MSI extension // metadataHeader is the header required by MSI extension
metadataHeader = "Metadata" metadataHeader = "Metadata"
) )
@ -54,6 +61,12 @@ type OAuthTokenProvider interface {
OAuthToken() string OAuthToken() string
} }
// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
error
Response() *http.Response
}
// Refresher is an interface for token refresh functionality // Refresher is an interface for token refresh functionality
type Refresher interface { type Refresher interface {
Refresh() error Refresh() error
@ -78,6 +91,11 @@ type Token struct {
Type string `json:"token_type"` Type string `json:"token_type"`
} }
// IsZero returns true if the token object is zero-initialized.
func (t Token) IsZero() bool {
return t == Token{}
}
// Expires returns the time.Time when the Token expires. // Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time { func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn) s, err := strconv.Atoi(t.ExpiresOn)
@ -145,6 +163,34 @@ type ServicePrincipalCertificateSecret struct {
type ServicePrincipalMSISecret struct { type ServicePrincipalMSISecret struct {
} }
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string
Password string
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string
AuthorizationCode string
RedirectURI string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil return nil
@ -199,25 +245,46 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
type ServicePrincipalToken struct { type ServicePrincipalToken struct {
Token Token
secret ServicePrincipalSecret secret ServicePrincipalSecret
oauthConfig OAuthConfig oauthConfig OAuthConfig
clientID string clientID string
resource string resource string
autoRefresh bool autoRefresh bool
refreshWithin time.Duration autoRefreshLock *sync.Mutex
sender Sender refreshWithin time.Duration
sender Sender
refreshCallbacks []TokenRefreshCallback refreshCallbacks []TokenRefreshCallback
} }
func validateOAuthConfig(oac OAuthConfig) error {
if oac.IsZero() {
return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
}
return nil
}
// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation. // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(id, "id"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
spt := &ServicePrincipalToken{ spt := &ServicePrincipalToken{
oauthConfig: oauthConfig, oauthConfig: oauthConfig,
secret: secret, secret: secret,
clientID: id, clientID: id,
resource: resource, resource: resource,
autoRefresh: true, autoRefresh: true,
autoRefreshLock: &sync.Mutex{},
refreshWithin: defaultRefresh, refreshWithin: defaultRefresh,
sender: &http.Client{}, sender: &http.Client{},
refreshCallbacks: callbacks, refreshCallbacks: callbacks,
@ -227,6 +294,18 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret( spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig, oauthConfig,
clientID, clientID,
@ -245,6 +324,18 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource. // credentials scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(secret, "secret"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret( return NewServicePrincipalTokenWithSecret(
oauthConfig, oauthConfig,
clientID, clientID,
@ -256,8 +347,23 @@ func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret s
) )
} }
// NewServicePrincipalTokenFromCertificate create a ServicePrincipalToken from the supplied pkcs12 bytes. // NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if certificate == nil {
return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
}
if privateKey == nil {
return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
}
return NewServicePrincipalTokenWithSecret( return NewServicePrincipalTokenWithSecret(
oauthConfig, oauthConfig,
clientID, clientID,
@ -270,6 +376,70 @@ func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID s
) )
} }
// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(username, "username"); err != nil {
return nil, err
}
if err := validateStringParam(password, "password"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalUsernamePasswordSecret{
Username: username,
Password: password,
},
callbacks...,
)
}
// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
return nil, err
}
if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
return nil, err
}
if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalAuthorizationCodeSecret{
ClientSecret: clientSecret,
AuthorizationCode: authorizationCode,
RedirectURI: redirectURI,
},
callbacks...,
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) { func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath) return getMSIVMEndpoint(msiPath)
@ -293,7 +463,29 @@ func getMSIVMEndpoint(path string) (string, error) {
} }
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
}
// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if userAssignedID != nil {
if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint // We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint) msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil { if err != nil {
@ -310,19 +502,49 @@ func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...
secret: &ServicePrincipalMSISecret{}, secret: &ServicePrincipalMSISecret{},
resource: resource, resource: resource,
autoRefresh: true, autoRefresh: true,
autoRefreshLock: &sync.Mutex{},
refreshWithin: defaultRefresh, refreshWithin: defaultRefresh,
sender: &http.Client{}, sender: &http.Client{},
refreshCallbacks: callbacks, refreshCallbacks: callbacks,
} }
if userAssignedID != nil {
spt.clientID = *userAssignedID
}
return spt, nil return spt, nil
} }
// internal type that implements TokenRefreshError
type tokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre tokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre tokenRefreshError) Response() *http.Response {
return tre.resp
}
func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
return tokenRefreshError{message: message, resp: resp}
}
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error { func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) { if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh() // take the lock then check to see if the token was already refreshed
spt.autoRefreshLock.Lock()
defer spt.autoRefreshLock.Unlock()
if spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
}
} }
return nil return nil
} }
@ -341,15 +563,28 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
} }
// Refresh obtains a fresh token for the Service Principal. // Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error { func (spt *ServicePrincipalToken) Refresh() error {
return spt.refreshInternal(spt.resource) return spt.refreshInternal(spt.resource)
} }
// RefreshExchange refreshes the token, but for a different resource. // RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.refreshInternal(resource) return spt.refreshInternal(resource)
} }
func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.secret.(type) {
case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret:
return OAuthGrantTypeAuthorizationCode
default:
return OAuthGrantTypeClientCredentials
}
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error { func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v := url.Values{} v := url.Values{}
v.Set("client_id", spt.clientID) v.Set("client_id", spt.clientID)
@ -359,7 +594,7 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v.Set("grant_type", OAuthGrantTypeRefreshToken) v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.RefreshToken) v.Set("refresh_token", spt.RefreshToken)
} else { } else {
v.Set("grant_type", OAuthGrantTypeClientCredentials) v.Set("grant_type", spt.getGrantType())
err := spt.secret.SetAuthenticationValues(spt, &v) err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil { if err != nil {
return err return err
@ -388,9 +623,9 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if err != nil { if err != nil {
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode) return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp)
} }
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)) return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
} }
if err != nil { if err != nil {

View File

@ -24,9 +24,12 @@ import (
) )
const ( const (
bearerChallengeHeader = "Www-Authenticate" bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer" bearer = "Bearer"
tenantID = "tenantID" tenantID = "tenantID"
apiKeyAuthorizerHeader = "Ocp-Apim-Subscription-Key"
bingAPISdkHeader = "X-BingApis-SDK-Client"
golangBingAPISdkHeaderValue = "Go-SDK"
) )
// Authorizer is the interface that provides a PrepareDecorator used to supply request // Authorizer is the interface that provides a PrepareDecorator used to supply request
@ -44,6 +47,53 @@ func (na NullAuthorizer) WithAuthorization() PrepareDecorator {
return WithNothing() return WithNothing()
} }
// APIKeyAuthorizer implements API Key authorization.
type APIKeyAuthorizer struct {
headers map[string]interface{}
queryParameters map[string]interface{}
}
// NewAPIKeyAuthorizerWithHeaders creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizerWithHeaders(headers map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(headers, nil)
}
// NewAPIKeyAuthorizerWithQueryParameters creates an ApiKeyAuthorizer with query parameters.
func NewAPIKeyAuthorizerWithQueryParameters(queryParameters map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(nil, queryParameters)
}
// NewAPIKeyAuthorizer creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizer(headers map[string]interface{}, queryParameters map[string]interface{}) *APIKeyAuthorizer {
return &APIKeyAuthorizer{headers: headers, queryParameters: queryParameters}
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP headers and Query Paramaters
func (aka *APIKeyAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return DecoratePreparer(p, WithHeaders(aka.headers), WithQueryParameters(aka.queryParameters))
}
}
// CognitiveServicesAuthorizer implements authorization for Cognitive Services.
type CognitiveServicesAuthorizer struct {
subscriptionKey string
}
// NewCognitiveServicesAuthorizer is
func NewCognitiveServicesAuthorizer(subscriptionKey string) *CognitiveServicesAuthorizer {
return &CognitiveServicesAuthorizer{subscriptionKey: subscriptionKey}
}
// WithAuthorization is
func (csa *CognitiveServicesAuthorizer) WithAuthorization() PrepareDecorator {
headers := make(map[string]interface{})
headers[apiKeyAuthorizerHeader] = csa.subscriptionKey
headers[bingAPISdkHeader] = golangBingAPISdkHeaderValue
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}
// BearerAuthorizer implements the bearer authorization // BearerAuthorizer implements the bearer authorization
type BearerAuthorizer struct { type BearerAuthorizer struct {
tokenProvider adal.OAuthTokenProvider tokenProvider adal.OAuthTokenProvider
@ -69,7 +119,11 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
if ok { if ok {
err := refresher.EnsureFresh() err := refresher.EnsureFresh()
if err != nil { if err != nil {
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", nil, var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL) "Failed to refresh the Token for request to %s", r.URL)
} }
} }
@ -179,3 +233,22 @@ func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
return bc, err return bc, err
} }
// EventGridKeyAuthorizer implements authorization for event grid using key authentication.
type EventGridKeyAuthorizer struct {
topicKey string
}
// NewEventGridKeyAuthorizer creates a new EventGridKeyAuthorizer
// with the specified topic key.
func NewEventGridKeyAuthorizer(topicKey string) EventGridKeyAuthorizer {
return EventGridKeyAuthorizer{topicKey: topicKey}
}
// WithAuthorization returns a PrepareDecorator that adds the aeg-sas-key authentication header.
func (egta EventGridKeyAuthorizer) WithAuthorization() PrepareDecorator {
headers := map[string]interface{}{
"aeg-sas-key": egta.topicKey,
}
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}

View File

@ -87,6 +87,9 @@ const (
// ResponseHasStatusCode returns true if the status code in the HTTP Response is in the passed set // ResponseHasStatusCode returns true if the status code in the HTTP Response is in the passed set
// and false otherwise. // and false otherwise.
func ResponseHasStatusCode(resp *http.Response, codes ...int) bool { func ResponseHasStatusCode(resp *http.Response, codes ...int) bool {
if resp == nil {
return false
}
return containsInt(codes, resp.StatusCode) return containsInt(codes, resp.StatusCode)
} }

View File

@ -16,6 +16,8 @@ package azure
import ( import (
"bytes" "bytes"
"context"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -37,6 +39,152 @@ const (
operationSucceeded string = "Succeeded" operationSucceeded string = "Succeeded"
) )
var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK}
// Future provides a mechanism to access the status and results of an asynchronous request.
// Since futures are stateful they should be passed by value to avoid race conditions.
type Future struct {
req *http.Request
resp *http.Response
ps pollingState
}
// NewFuture returns a new Future object initialized with the specified request.
func NewFuture(req *http.Request) Future {
return Future{req: req}
}
// Response returns the last HTTP response or nil if there isn't one.
func (f Future) Response() *http.Response {
return f.resp
}
// Status returns the last status message of the operation.
func (f Future) Status() string {
if f.ps.State == "" {
return "Unknown"
}
return f.ps.State
}
// PollingMethod returns the method used to monitor the status of the asynchronous operation.
func (f Future) PollingMethod() PollingMethodType {
return f.ps.PollingMethod
}
// Done queries the service to see if the operation has completed.
func (f *Future) Done(sender autorest.Sender) (bool, error) {
// exit early if this future has terminated
if f.ps.hasTerminated() {
return true, f.errorInfo()
}
resp, err := sender.Do(f.req)
f.resp = resp
if err != nil || !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) {
return false, err
}
err = updatePollingState(resp, &f.ps)
if err != nil {
return false, err
}
if f.ps.hasTerminated() {
return true, f.errorInfo()
}
f.req, err = newPollingRequest(f.ps)
return false, err
}
// GetPollingDelay returns a duration the application should wait before checking
// the status of the asynchronous request and true; this value is returned from
// the service via the Retry-After response header. If the header wasn't returned
// then the function returns the zero-value time.Duration and false.
func (f Future) GetPollingDelay() (time.Duration, bool) {
if f.resp == nil {
return 0, false
}
retry := f.resp.Header.Get(autorest.HeaderRetryAfter)
if retry == "" {
return 0, false
}
d, err := time.ParseDuration(retry + "s")
if err != nil {
panic(err)
}
return d, true
}
// WaitForCompletion will return when one of the following conditions is met: the long
// running operation has completed, the provided context is cancelled, or the client's
// polling duration has been exceeded. It will retry failed polling attempts based on
// the retry value defined in the client up to the maximum retry attempts.
func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) error {
ctx, cancel := context.WithTimeout(ctx, client.PollingDuration)
defer cancel()
done, err := f.Done(client)
for attempts := 0; !done; done, err = f.Done(client) {
if attempts >= client.RetryAttempts {
return autorest.NewErrorWithError(err, "azure", "WaitForCompletion", f.resp, "the number of retries has been exceeded")
}
// we want delayAttempt to be zero in the non-error case so
// that DelayForBackoff doesn't perform exponential back-off
var delayAttempt int
var delay time.Duration
if err == nil {
// check for Retry-After delay, if not present use the client's polling delay
var ok bool
delay, ok = f.GetPollingDelay()
if !ok {
delay = client.PollingDelay
}
} else {
// there was an error polling for status so perform exponential
// back-off based on the number of attempts using the client's retry
// duration. update attempts after delayAttempt to avoid off-by-one.
delayAttempt = attempts
delay = client.RetryDuration
attempts++
}
// wait until the delay elapses or the context is cancelled
delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, ctx.Done())
if !delayElapsed {
return autorest.NewErrorWithError(ctx.Err(), "azure", "WaitForCompletion", f.resp, "context has been cancelled")
}
}
return err
}
// if the operation failed the polling state will contain
// error information and implements the error interface
func (f *Future) errorInfo() error {
if !f.ps.hasSucceeded() {
return f.ps
}
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (f Future) MarshalJSON() ([]byte, error) {
return json.Marshal(&f.ps)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (f *Future) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &f.ps)
if err != nil {
return err
}
f.req, err = newPollingRequest(f.ps)
return err
}
// DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure // DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure
// long-running operation. It will delay between requests for the duration specified in the // long-running operation. It will delay between requests for the duration specified in the
// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by // RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
@ -48,8 +196,7 @@ func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator {
if err != nil { if err != nil {
return resp, err return resp, err
} }
pollingCodes := []int{http.StatusAccepted, http.StatusCreated, http.StatusOK} if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) {
if !autorest.ResponseHasStatusCode(resp, pollingCodes...) {
return resp, nil return resp, nil
} }
@ -66,10 +213,11 @@ func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator {
break break
} }
r, err = newPollingRequest(resp, ps) r, err = newPollingRequest(ps)
if err != nil { if err != nil {
return resp, err return resp, err
} }
r.Cancel = resp.Request.Cancel
delay = autorest.GetRetryAfter(resp, delay) delay = autorest.GetRetryAfter(resp, delay)
resp, err = autorest.SendWithSender(s, r, resp, err = autorest.SendWithSender(s, r,
@ -86,20 +234,15 @@ func getAsyncOperation(resp *http.Response) string {
} }
func hasSucceeded(state string) bool { func hasSucceeded(state string) bool {
return state == operationSucceeded return strings.EqualFold(state, operationSucceeded)
} }
func hasTerminated(state string) bool { func hasTerminated(state string) bool {
switch state { return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded)
case operationCanceled, operationFailed, operationSucceeded:
return true
default:
return false
}
} }
func hasFailed(state string) bool { func hasFailed(state string) bool {
return state == operationFailed return strings.EqualFold(state, operationFailed)
} }
type provisioningTracker interface { type provisioningTracker interface {
@ -160,36 +303,42 @@ func (ps provisioningStatus) hasProvisioningError() bool {
return ps.ProvisioningError != ServiceError{} return ps.ProvisioningError != ServiceError{}
} }
type pollingResponseFormat string // PollingMethodType defines a type used for enumerating polling mechanisms.
type PollingMethodType string
const ( const (
usesOperationResponse pollingResponseFormat = "OperationResponse" // PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header.
usesProvisioningStatus pollingResponseFormat = "ProvisioningStatus" PollingAsyncOperation PollingMethodType = "AsyncOperation"
formatIsUnknown pollingResponseFormat = ""
// PollingLocation indicates the polling method uses the Location header.
PollingLocation PollingMethodType = "Location"
// PollingUnknown indicates an unknown polling method and is the default value.
PollingUnknown PollingMethodType = ""
) )
type pollingState struct { type pollingState struct {
responseFormat pollingResponseFormat PollingMethod PollingMethodType `json:"pollingMethod"`
uri string URI string `json:"uri"`
state string State string `json:"state"`
code string Code string `json:"code"`
message string Message string `json:"message"`
} }
func (ps pollingState) hasSucceeded() bool { func (ps pollingState) hasSucceeded() bool {
return hasSucceeded(ps.state) return hasSucceeded(ps.State)
} }
func (ps pollingState) hasTerminated() bool { func (ps pollingState) hasTerminated() bool {
return hasTerminated(ps.state) return hasTerminated(ps.State)
} }
func (ps pollingState) hasFailed() bool { func (ps pollingState) hasFailed() bool {
return hasFailed(ps.state) return hasFailed(ps.State)
} }
func (ps pollingState) Error() string { func (ps pollingState) Error() string {
return fmt.Sprintf("Long running operation terminated with status '%s': Code=%q Message=%q", ps.state, ps.code, ps.message) return fmt.Sprintf("Long running operation terminated with status '%s': Code=%q Message=%q", ps.State, ps.Code, ps.Message)
} }
// updatePollingState maps the operation status -- retrieved from either a provisioningState // updatePollingState maps the operation status -- retrieved from either a provisioningState
@ -204,7 +353,7 @@ func updatePollingState(resp *http.Response, ps *pollingState) error {
// -- The first response will always be a provisioningStatus response; only the polling requests, // -- The first response will always be a provisioningStatus response; only the polling requests,
// depending on the header returned, may be something otherwise. // depending on the header returned, may be something otherwise.
var pt provisioningTracker var pt provisioningTracker
if ps.responseFormat == usesOperationResponse { if ps.PollingMethod == PollingAsyncOperation {
pt = &operationResource{} pt = &operationResource{}
} else { } else {
pt = &provisioningStatus{} pt = &provisioningStatus{}
@ -212,30 +361,30 @@ func updatePollingState(resp *http.Response, ps *pollingState) error {
// If this is the first request (that is, the polling response shape is unknown), determine how // If this is the first request (that is, the polling response shape is unknown), determine how
// to poll and what to expect // to poll and what to expect
if ps.responseFormat == formatIsUnknown { if ps.PollingMethod == PollingUnknown {
req := resp.Request req := resp.Request
if req == nil { if req == nil {
return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing") return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing")
} }
// Prefer the Azure-AsyncOperation header // Prefer the Azure-AsyncOperation header
ps.uri = getAsyncOperation(resp) ps.URI = getAsyncOperation(resp)
if ps.uri != "" { if ps.URI != "" {
ps.responseFormat = usesOperationResponse ps.PollingMethod = PollingAsyncOperation
} else { } else {
ps.responseFormat = usesProvisioningStatus ps.PollingMethod = PollingLocation
} }
// Else, use the Location header // Else, use the Location header
if ps.uri == "" { if ps.URI == "" {
ps.uri = autorest.GetLocation(resp) ps.URI = autorest.GetLocation(resp)
} }
// Lastly, requests against an existing resource, use the last request URI // Lastly, requests against an existing resource, use the last request URI
if ps.uri == "" { if ps.URI == "" {
m := strings.ToUpper(req.Method) m := strings.ToUpper(req.Method)
if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet { if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet {
ps.uri = req.URL.String() ps.URI = req.URL.String()
} }
} }
} }
@ -256,23 +405,23 @@ func updatePollingState(resp *http.Response, ps *pollingState) error {
// -- Unknown states are per-service inprogress states // -- Unknown states are per-service inprogress states
// -- Otherwise, infer state from HTTP status code // -- Otherwise, infer state from HTTP status code
if pt.hasTerminated() { if pt.hasTerminated() {
ps.state = pt.state() ps.State = pt.state()
} else if pt.state() != "" { } else if pt.state() != "" {
ps.state = operationInProgress ps.State = operationInProgress
} else { } else {
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusAccepted: case http.StatusAccepted:
ps.state = operationInProgress ps.State = operationInProgress
case http.StatusNoContent, http.StatusCreated, http.StatusOK: case http.StatusNoContent, http.StatusCreated, http.StatusOK:
ps.state = operationSucceeded ps.State = operationSucceeded
default: default:
ps.state = operationFailed ps.State = operationFailed
} }
} }
if ps.state == operationInProgress && ps.uri == "" { if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" {
return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL) return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL)
} }
@ -281,36 +430,49 @@ func updatePollingState(resp *http.Response, ps *pollingState) error {
// -- Response // -- Response
// -- Otherwise, Unknown // -- Otherwise, Unknown
if ps.hasFailed() { if ps.hasFailed() {
if ps.responseFormat == usesOperationResponse { if ps.PollingMethod == PollingAsyncOperation {
or := pt.(*operationResource) or := pt.(*operationResource)
ps.code = or.OperationError.Code ps.Code = or.OperationError.Code
ps.message = or.OperationError.Message ps.Message = or.OperationError.Message
} else { } else {
p := pt.(*provisioningStatus) p := pt.(*provisioningStatus)
if p.hasProvisioningError() { if p.hasProvisioningError() {
ps.code = p.ProvisioningError.Code ps.Code = p.ProvisioningError.Code
ps.message = p.ProvisioningError.Message ps.Message = p.ProvisioningError.Message
} else { } else {
ps.code = "Unknown" ps.Code = "Unknown"
ps.message = "None" ps.Message = "None"
} }
} }
} }
return nil return nil
} }
func newPollingRequest(resp *http.Response, ps pollingState) (*http.Request, error) { func newPollingRequest(ps pollingState) (*http.Request, error) {
req := resp.Request reqPoll, err := autorest.Prepare(&http.Request{},
if req == nil {
return nil, autorest.NewError("azure", "newPollingRequest", "Azure Polling Error - Original HTTP request is missing")
}
reqPoll, err := autorest.Prepare(&http.Request{Cancel: req.Cancel},
autorest.AsGet(), autorest.AsGet(),
autorest.WithBaseURL(ps.uri)) autorest.WithBaseURL(ps.URI))
if err != nil { if err != nil {
return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.uri) return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.URI)
} }
return reqPoll, nil return reqPoll, nil
} }
// AsyncOpIncompleteError is the type that's returned from a future that has not completed.
type AsyncOpIncompleteError struct {
// FutureType is the name of the type composed of a azure.Future.
FutureType string
}
// Error returns an error message including the originating type name of the error.
func (e AsyncOpIncompleteError) Error() string {
return fmt.Sprintf("%s: asynchronous operation has not completed", e.FutureType)
}
// NewAsyncOpIncompleteError creates a new AsyncOpIncompleteError with the specified parameters.
func NewAsyncOpIncompleteError(futureType string) AsyncOpIncompleteError {
return AsyncOpIncompleteError{
FutureType: futureType,
}
}

View File

@ -15,10 +15,17 @@ package azure
// limitations under the License. // limitations under the License.
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil"
"os"
"strings" "strings"
) )
// EnvironmentFilepathName captures the name of the environment variable containing the path to the file
// to be used while populating the Azure Environment.
const EnvironmentFilepathName = "AZURE_ENVIRONMENT_FILEPATH"
var environments = map[string]Environment{ var environments = map[string]Environment{
"AZURECHINACLOUD": ChinaCloud, "AZURECHINACLOUD": ChinaCloud,
"AZUREGERMANCLOUD": GermanCloud, "AZUREGERMANCLOUD": GermanCloud,
@ -76,10 +83,10 @@ var (
PublishSettingsURL: "https://manage.windowsazure.us/publishsettings/index", PublishSettingsURL: "https://manage.windowsazure.us/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/", ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/", ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/", ActiveDirectoryEndpoint: "https://login.microsoftonline.us/",
GalleryEndpoint: "https://gallery.usgovcloudapi.net/", GalleryEndpoint: "https://gallery.usgovcloudapi.net/",
KeyVaultEndpoint: "https://vault.usgovcloudapi.net/", KeyVaultEndpoint: "https://vault.usgovcloudapi.net/",
GraphEndpoint: "https://graph.usgovcloudapi.net/", GraphEndpoint: "https://graph.windows.net/",
StorageEndpointSuffix: "core.usgovcloudapi.net", StorageEndpointSuffix: "core.usgovcloudapi.net",
SQLDatabaseDNSSuffix: "database.usgovcloudapi.net", SQLDatabaseDNSSuffix: "database.usgovcloudapi.net",
TrafficManagerDNSSuffix: "usgovtrafficmanager.net", TrafficManagerDNSSuffix: "usgovtrafficmanager.net",
@ -133,12 +140,37 @@ var (
} }
) )
// EnvironmentFromName returns an Environment based on the common name specified // EnvironmentFromName returns an Environment based on the common name specified.
func EnvironmentFromName(name string) (Environment, error) { func EnvironmentFromName(name string) (Environment, error) {
// IMPORTANT
// As per @radhikagupta5:
// This is technical debt, fundamentally here because Kubernetes is not currently accepting
// contributions to the providers. Once that is an option, the provider should be updated to
// directly call `EnvironmentFromFile`. Until then, we rely on dispatching Azure Stack environment creation
// from this method based on the name that is provided to us.
if strings.EqualFold(name, "AZURESTACKCLOUD") {
return EnvironmentFromFile(os.Getenv(EnvironmentFilepathName))
}
name = strings.ToUpper(name) name = strings.ToUpper(name)
env, ok := environments[name] env, ok := environments[name]
if !ok { if !ok {
return env, fmt.Errorf("autorest/azure: There is no cloud environment matching the name %q", name) return env, fmt.Errorf("autorest/azure: There is no cloud environment matching the name %q", name)
} }
return env, nil return env, nil
} }
// EnvironmentFromFile loads an Environment from a configuration file available on disk.
// This function is particularly useful in the Hybrid Cloud model, where one must define their own
// endpoints.
func EnvironmentFromFile(location string) (unmarshaled Environment, err error) {
fileContents, err := ioutil.ReadFile(location)
if err != nil {
return
}
err = json.Unmarshal(fileContents, &unmarshaled)
return
}

View File

@ -1,3 +1,17 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure package azure
import ( import (
@ -30,7 +44,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
return resp, err return resp, err
} }
if resp.StatusCode != http.StatusConflict { if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration {
return resp, err return resp, err
} }
var re RequestError var re RequestError
@ -41,15 +55,16 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
if err != nil { if err != nil {
return resp, err return resp, err
} }
err = re
if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" { if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" {
err = register(client, r, re) regErr := register(client, r, re)
if err != nil { if regErr != nil {
return resp, fmt.Errorf("failed auto registering Resource Provider: %s", err) return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err)
} }
} }
} }
return resp, errors.New("failed request and resource provider registration") return resp, fmt.Errorf("failed request: %s", err)
}) })
} }
} }
@ -144,7 +159,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
} }
req.Cancel = originalReq.Cancel req.Cancel = originalReq.Cancel
resp, err := autorest.SendWithSender(client.Sender, req, resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
) )
if err != nil { if err != nil {

View File

@ -35,6 +35,9 @@ const (
// DefaultRetryAttempts is number of attempts for retry status codes (5xx). // DefaultRetryAttempts is number of attempts for retry status codes (5xx).
DefaultRetryAttempts = 3 DefaultRetryAttempts = 3
// DefaultRetryDuration is the duration to wait between retries.
DefaultRetryDuration = 30 * time.Second
) )
var ( var (
@ -163,6 +166,9 @@ type Client struct {
UserAgent string UserAgent string
Jar http.CookieJar Jar http.CookieJar
// Set to true to skip attempted registration of resource providers (false by default).
SkipResourceProviderRegistration bool
} }
// NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed // NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed
@ -172,9 +178,10 @@ func NewClientWithUserAgent(ua string) Client {
PollingDelay: DefaultPollingDelay, PollingDelay: DefaultPollingDelay,
PollingDuration: DefaultPollingDuration, PollingDuration: DefaultPollingDuration,
RetryAttempts: DefaultRetryAttempts, RetryAttempts: DefaultRetryAttempts,
RetryDuration: 30 * time.Second, RetryDuration: DefaultRetryDuration,
UserAgent: defaultUserAgent, UserAgent: defaultUserAgent,
} }
c.Sender = c.sender()
c.AddToUserAgent(ua) c.AddToUserAgent(ua)
return c return c
} }
@ -200,11 +207,17 @@ func (c Client) Do(r *http.Request) (*http.Response, error) {
c.WithInspection(), c.WithInspection(),
c.WithAuthorization()) c.WithAuthorization())
if err != nil { if err != nil {
return nil, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed") var resp *http.Response
if detErr, ok := err.(DetailedError); ok {
// if the authorization failed (e.g. invalid credentials) there will
// be a response associated with the error, be sure to return it.
resp = detErr.Response
}
return resp, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
} }
resp, err := SendWithSender(c.sender(), r) resp, err := SendWithSender(c.sender(), r)
Respond(resp, Respond(resp, c.ByInspecting())
c.ByInspecting())
return resp, err return resp, err
} }

View File

@ -27,8 +27,9 @@ import (
) )
const ( const (
mimeTypeJSON = "application/json" mimeTypeJSON = "application/json"
mimeTypeFormPost = "application/x-www-form-urlencoded" mimeTypeOctetStream = "application/octet-stream"
mimeTypeFormPost = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization" headerAuthorization = "Authorization"
headerContentType = "Content-Type" headerContentType = "Content-Type"
@ -112,6 +113,28 @@ func WithHeader(header string, value string) PrepareDecorator {
} }
} }
// WithHeaders returns a PrepareDecorator that sets the specified HTTP headers of the http.Request to
// the passed value. It canonicalizes the passed headers name (via http.CanonicalHeaderKey) before
// adding them.
func WithHeaders(headers map[string]interface{}) PrepareDecorator {
h := ensureValueStrings(headers)
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.Header == nil {
r.Header = make(http.Header)
}
for name, value := range h {
r.Header.Set(http.CanonicalHeaderKey(name), value)
}
}
return r, err
})
}
}
// WithBearerAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose // WithBearerAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the supplied token. // value is "Bearer " followed by the supplied token.
func WithBearerAuthorization(token string) PrepareDecorator { func WithBearerAuthorization(token string) PrepareDecorator {
@ -142,6 +165,11 @@ func AsJSON() PrepareDecorator {
return AsContentType(mimeTypeJSON) return AsContentType(mimeTypeJSON)
} }
// AsOctetStream returns a PrepareDecorator that adds the "application/octet-stream" Content-Type header.
func AsOctetStream() PrepareDecorator {
return AsContentType(mimeTypeOctetStream)
}
// WithMethod returns a PrepareDecorator that sets the HTTP method of the passed request. The // WithMethod returns a PrepareDecorator that sets the HTTP method of the passed request. The
// decorator does not validate that the passed method string is a known HTTP method. // decorator does not validate that the passed method string is a known HTTP method.
func WithMethod(method string) PrepareDecorator { func WithMethod(method string) PrepareDecorator {
@ -215,6 +243,11 @@ func WithFormData(v url.Values) PrepareDecorator {
r, err := p.Prepare(r) r, err := p.Prepare(r)
if err == nil { if err == nil {
s := v.Encode() s := v.Encode()
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(http.CanonicalHeaderKey(headerContentType), mimeTypeFormPost)
r.ContentLength = int64(len(s)) r.ContentLength = int64(len(s))
r.Body = ioutil.NopCloser(strings.NewReader(s)) r.Body = ioutil.NopCloser(strings.NewReader(s))
} }
@ -430,11 +463,16 @@ func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorato
if r.URL == nil { if r.URL == nil {
return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL") return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL")
} }
v := r.URL.Query() v := r.URL.Query()
for key, value := range parameters { for key, value := range parameters {
v.Add(key, value) d, err := url.QueryUnescape(value)
if err != nil {
return r, err
}
v.Add(key, d)
} }
r.URL.RawQuery = createQuery(v) r.URL.RawQuery = v.Encode()
} }
return r, err return r, err
}) })

View File

@ -215,19 +215,26 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
rr := NewRetriableRequest(r) rr := NewRetriableRequest(r)
// Increment to add the first call (attempts denotes number of retries) // Increment to add the first call (attempts denotes number of retries)
attempts++ attempts++
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; attempt < attempts; {
err = rr.Prepare() err = rr.Prepare()
if err != nil { if err != nil {
return resp, err return resp, err
} }
resp, err = s.Do(rr.Request()) resp, err = s.Do(rr.Request())
if err != nil || !ResponseHasStatusCode(resp, codes...) { // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
return resp, err return resp, err
} }
delayed := DelayWithRetryAfter(resp, r.Cancel) delayed := DelayWithRetryAfter(resp, r.Cancel)
if !delayed { if !delayed {
DelayForBackoff(backoff, attempt, r.Cancel) DelayForBackoff(backoff, attempt, r.Cancel)
} }
// don't count a 429 against the number of attempts
// so that we continue to retry until it succeeds
if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
attempt++
}
} }
return resp, err return resp, err
}) })
@ -237,6 +244,9 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
// DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header in // DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header in
// responses with status code 429 // responses with status code 429
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool { func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
if resp == nil {
return false
}
retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After")) retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After"))
if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 { if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 {
select { select {

View File

@ -20,10 +20,12 @@ import (
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"sort"
"strings" "strings"
"github.com/Azure/go-autorest/autorest/adal"
) )
// EncodedAs is a series of constants specifying various data encodings // EncodedAs is a series of constants specifying various data encodings
@ -137,13 +139,38 @@ func MapToValues(m map[string]interface{}) url.Values {
return v return v
} }
// String method converts interface v to string. If interface is a list, it // AsStringSlice method converts interface{} to []string. This expects a
// joins list elements using separator. //that the parameter passed to be a slice or array of a type that has the underlying
func String(v interface{}, sep ...string) string { //type a string.
if len(sep) > 0 { func AsStringSlice(s interface{}) ([]string, error) {
return ensureValueString(strings.Join(v.([]string), sep[0])) v := reflect.ValueOf(s)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, NewError("autorest", "AsStringSlice", "the value's type is not an array.")
} }
return ensureValueString(v) stringSlice := make([]string, 0, v.Len())
for i := 0; i < v.Len(); i++ {
stringSlice = append(stringSlice, v.Index(i).String())
}
return stringSlice, nil
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using the seperator. Note that only sep[0] will be used for
// joining if any separator is specified.
func String(v interface{}, sep ...string) string {
if len(sep) == 0 {
return ensureValueString(v)
}
stringSlice, ok := v.([]string)
if ok == false {
var err error
stringSlice, err = AsStringSlice(v)
if err != nil {
panic(fmt.Sprintf("autorest: Couldn't convert value to a string %s.", err))
}
}
return ensureValueString(strings.Join(stringSlice, sep[0]))
} }
// Encode method encodes url path and query parameters. // Encode method encodes url path and query parameters.
@ -167,26 +194,25 @@ func queryEscape(s string) string {
return url.QueryEscape(s) return url.QueryEscape(s)
} }
// This method is same as Encode() method of "net/url" go package, // ChangeToGet turns the specified http.Request into a GET (it assumes it wasn't).
// except it does not encode the query parameters because they // This is mainly useful for long-running operations that use the Azure-AsyncOperation
// already come encoded. It formats values map in query format (bar=foo&a=b). // header, so we change the initial PUT into a GET to retrieve the final result.
func createQuery(v url.Values) string { func ChangeToGet(req *http.Request) *http.Request {
var buf bytes.Buffer req.Method = "GET"
keys := make([]string, 0, len(v)) req.Body = nil
for k := range v { req.ContentLength = 0
keys = append(keys, k) req.Header.Del("Content-Length")
} return req
sort.Strings(keys) }
for _, k := range keys {
vs := v[k] // IsTokenRefreshError returns true if the specified error implements the TokenRefreshError
prefix := url.QueryEscape(k) + "=" // interface. If err is a DetailedError it will walk the chain of Original errors.
for _, v := range vs { func IsTokenRefreshError(err error) bool {
if buf.Len() > 0 { if _, ok := err.(adal.TokenRefreshError); ok {
buf.WriteByte('&') return true
} }
buf.WriteString(prefix) if de, ok := err.(DetailedError); ok {
buf.WriteString(v) return IsTokenRefreshError(de.Original)
} }
} return false
return buf.String()
} }

View File

@ -22,9 +22,9 @@ import (
) )
const ( const (
major = 8 major = 9
minor = 0 minor = 8
patch = 0 patch = 1
tag = "" tag = ""
) )