diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index bca54744..7ea38710 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -1,6 +1,6 @@ { "ImportPath": "k8s.io/client-go", - "GoVersion": "go1.9", + "GoVersion": "go1.10", "GodepVersion": "v80", "Packages": [ "./..." @@ -16,19 +16,23 @@ }, { "ImportPath": "github.com/Azure/go-autorest/autorest", - "Rev": "1ff28809256a84bb6966640ff3d0371af82ccba4" + "Rev": "bca49d5b51a50dc5bb17bbf6204c711c6dbded06" }, { "ImportPath": "github.com/Azure/go-autorest/autorest/adal", - "Rev": "1ff28809256a84bb6966640ff3d0371af82ccba4" + "Rev": "bca49d5b51a50dc5bb17bbf6204c711c6dbded06" }, { "ImportPath": "github.com/Azure/go-autorest/autorest/azure", - "Rev": "1ff28809256a84bb6966640ff3d0371af82ccba4" + "Rev": "bca49d5b51a50dc5bb17bbf6204c711c6dbded06" }, { "ImportPath": "github.com/Azure/go-autorest/autorest/date", - "Rev": "1ff28809256a84bb6966640ff3d0371af82ccba4" + "Rev": "bca49d5b51a50dc5bb17bbf6204c711c6dbded06" + }, + { + "ImportPath": "github.com/Azure/go-autorest/version", + "Rev": "bca49d5b51a50dc5bb17bbf6204c711c6dbded06" }, { "ImportPath": "github.com/davecgh/go-spew/spew", diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go index f570d540..bee5e61d 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go @@ -26,10 +26,10 @@ const ( // OAuthConfig represents the endpoints needed // in OAuth operations type OAuthConfig struct { - AuthorityEndpoint url.URL - AuthorizeEndpoint url.URL - TokenEndpoint url.URL - DeviceCodeEndpoint url.URL + AuthorityEndpoint url.URL `json:"authorityEndpoint"` + AuthorizeEndpoint url.URL `json:"authorizeEndpoint"` + TokenEndpoint url.URL `json:"tokenEndpoint"` + DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"` } // IsZero returns true if the OAuthConfig object is zero-initialized. diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go index 24641b62..32aea838 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go @@ -15,14 +15,18 @@ package adal // limitations under the License. import ( + "context" "crypto/rand" "crypto/rsa" "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" + "errors" "fmt" "io/ioutil" + "math" + "net" "net/http" "net/url" "strconv" @@ -31,6 +35,7 @@ import ( "time" "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/version" "github.com/dgrijalva/jwt-go" ) @@ -57,6 +62,9 @@ const ( // msiEndpoint is the well known endpoint for getting MSI authentications tokens msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + + // the default number of attempts to refresh an MSI authentication token + defaultMaxMSIRefreshAttempts = 5 ) // OAuthTokenProvider is an interface which should be implemented by an access token retriever @@ -77,6 +85,13 @@ type Refresher interface { EnsureFresh() error } +// RefresherWithContext is an interface for token refresh functionality +type RefresherWithContext interface { + RefreshWithContext(ctx context.Context) error + RefreshExchangeWithContext(ctx context.Context, resource string) error + EnsureFreshWithContext(ctx context.Context) error +} + // TokenRefreshCallback is the type representing callbacks that will be called after // a successful token refresh type TokenRefreshCallback func(Token) error @@ -127,6 +142,12 @@ func (t *Token) OAuthToken() string { return t.AccessToken } +// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form +// that is submitted when acquiring an oAuth token. +type ServicePrincipalSecret interface { + SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +} + // ServicePrincipalNoSecret represents a secret type that contains no secret // meaning it is not valid for fetching a fresh token. This is used by Manual type ServicePrincipalNoSecret struct { @@ -138,15 +159,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") } -// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form -// that is submitted when acquiring an oAuth token. -type ServicePrincipalSecret interface { - SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +// MarshalJSON implements the json.Marshaler interface. +func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalNoSecret", + }) } // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. type ServicePrincipalTokenSecret struct { - ClientSecret string + ClientSecret string `json:"value"` } // SetAuthenticationValues is a method of the interface ServicePrincipalSecret. @@ -156,49 +181,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser return nil } +// MarshalJSON implements the json.Marshaler interface. +func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalTokenSecret", + Value: tokenSecret.ClientSecret, + }) +} + // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. type ServicePrincipalCertificateSecret struct { Certificate *x509.Certificate PrivateKey *rsa.PrivateKey } -// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. -type ServicePrincipalMSISecret struct { -} - -// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. -type ServicePrincipalUsernamePasswordSecret struct { - Username string - Password string -} - -// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. -type ServicePrincipalAuthorizationCodeSecret struct { - ClientSecret string - AuthorizationCode string - RedirectURI string -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("code", secret.AuthorizationCode) - v.Set("client_secret", secret.ClientSecret) - v.Set("redirect_uri", secret.RedirectURI) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("username", secret.Username) - v.Set("password", secret.Password) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - return nil -} - // SignJwt returns the JWT signed with the certificate's private key. func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { hasher := sha1.New() @@ -219,9 +219,9 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo token := jwt.New(jwt.SigningMethodRS256) token.Header["x5t"] = thumbprint token.Claims = jwt.MapClaims{ - "aud": spt.oauthConfig.TokenEndpoint.String(), - "iss": spt.clientID, - "sub": spt.clientID, + "aud": spt.inner.OauthConfig.TokenEndpoint.String(), + "iss": spt.inner.ClientID, + "sub": spt.inner.ClientID, "jti": base64.URLEncoding.EncodeToString(jti), "nbf": time.Now().Unix(), "exp": time.Now().Add(time.Hour * 24).Unix(), @@ -244,19 +244,151 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se return nil } +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") +} + +// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. +type ServicePrincipalMSISecret struct { +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") +} + +// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. +type ServicePrincipalUsernamePasswordSecret struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("username", secret.Username) + v.Set("password", secret.Password) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Username string `json:"username"` + Password string `json:"password"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalUsernamePasswordSecret", + Username: secret.Username, + Password: secret.Password, + }) +} + +// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. +type ServicePrincipalAuthorizationCodeSecret struct { + ClientSecret string `json:"value"` + AuthorizationCode string `json:"authCode"` + RedirectURI string `json:"redirect"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("code", secret.AuthorizationCode) + v.Set("client_secret", secret.ClientSecret) + v.Set("redirect_uri", secret.RedirectURI) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + AuthCode string `json:"authCode"` + Redirect string `json:"redirect"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalAuthorizationCodeSecret", + Value: secret.ClientSecret, + AuthCode: secret.AuthorizationCode, + Redirect: secret.RedirectURI, + }) +} + // ServicePrincipalToken encapsulates a Token created for a Service Principal. type ServicePrincipalToken struct { - token Token - secret ServicePrincipalSecret - oauthConfig OAuthConfig - clientID string - resource string - autoRefresh bool - refreshLock *sync.RWMutex - refreshWithin time.Duration - sender Sender - + inner servicePrincipalToken + refreshLock *sync.RWMutex + sender Sender refreshCallbacks []TokenRefreshCallback + // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. + MaxMSIRefreshAttempts int +} + +// MarshalTokenJSON returns the marshalled inner token. +func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { + return json.Marshal(spt.inner.Token) +} + +// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. +func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { + spt.refreshCallbacks = callbacks +} + +// MarshalJSON implements the json.Marshaler interface. +func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { + return json.Marshal(spt.inner) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { + // need to determine the token type + raw := map[string]interface{}{} + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + secret := raw["secret"].(map[string]interface{}) + switch secret["type"] { + case "ServicePrincipalNoSecret": + spt.inner.Secret = &ServicePrincipalNoSecret{} + case "ServicePrincipalTokenSecret": + spt.inner.Secret = &ServicePrincipalTokenSecret{} + case "ServicePrincipalCertificateSecret": + return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") + case "ServicePrincipalMSISecret": + return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") + case "ServicePrincipalUsernamePasswordSecret": + spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} + case "ServicePrincipalAuthorizationCodeSecret": + spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} + default: + return fmt.Errorf("unrecognized token type '%s'", secret["type"]) + } + err = json.Unmarshal(data, &spt.inner) + if err != nil { + return err + } + spt.refreshLock = &sync.RWMutex{} + spt.sender = &http.Client{} + return nil +} + +// internal type used for marshalling/unmarshalling +type servicePrincipalToken struct { + Token Token `json:"token"` + Secret ServicePrincipalSecret `json:"secret"` + OauthConfig OAuthConfig `json:"oauth"` + ClientID string `json:"clientID"` + Resource string `json:"resource"` + AutoRefresh bool `json:"autoRefresh"` + RefreshWithin time.Duration `json:"refreshWithin"` } func validateOAuthConfig(oac OAuthConfig) error { @@ -281,13 +413,15 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso return nil, fmt.Errorf("parameter 'secret' cannot be nil") } spt := &ServicePrincipalToken{ - oauthConfig: oauthConfig, - secret: secret, - clientID: id, - resource: resource, - autoRefresh: true, + inner: servicePrincipalToken{ + OauthConfig: oauthConfig, + Secret: secret, + ClientID: id, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, + }, refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, sender: &http.Client{}, refreshCallbacks: callbacks, } @@ -318,7 +452,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s return nil, err } - spt.token = token + spt.inner.Token = token + + return spt, nil +} + +// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret +func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if err := validateOAuthConfig(oauthConfig); err != nil { + return nil, err + } + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(resource, "resource"); err != nil { + return nil, err + } + if secret == nil { + return nil, fmt.Errorf("parameter 'secret' cannot be nil") + } + if token.IsZero() { + return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") + } + spt, err := NewServicePrincipalTokenWithSecret( + oauthConfig, + clientID, + resource, + secret, + callbacks...) + if err != nil { + return nil, err + } + + spt.inner.Token = token return spt, nil } @@ -486,20 +652,23 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI msiEndpointURL.RawQuery = v.Encode() spt := &ServicePrincipalToken{ - oauthConfig: OAuthConfig{ - TokenEndpoint: *msiEndpointURL, + inner: servicePrincipalToken{ + OauthConfig: OAuthConfig{ + TokenEndpoint: *msiEndpointURL, + }, + Secret: &ServicePrincipalMSISecret{}, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, }, - secret: &ServicePrincipalMSISecret{}, - resource: resource, - autoRefresh: true, - refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, - sender: &http.Client{}, - refreshCallbacks: callbacks, + refreshLock: &sync.RWMutex{}, + sender: &http.Client{}, + refreshCallbacks: callbacks, + MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, } if userAssignedID != nil { - spt.clientID = *userAssignedID + spt.inner.ClientID = *userAssignedID } return spt, nil @@ -528,12 +697,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError // EnsureFresh will refresh the token if it will expire within the refresh window (as set by // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. func (spt *ServicePrincipalToken) EnsureFresh() error { - if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { + return spt.EnsureFreshWithContext(context.Background()) +} + +// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by +// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. +func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { + if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { // take the write lock then check to see if the token was already refreshed spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - if spt.token.WillExpireIn(spt.refreshWithin) { - return spt.refreshInternal(spt.resource) + if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { + return spt.refreshInternal(ctx, spt.inner.Resource) } } return nil @@ -543,7 +718,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error { func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { if spt.refreshCallbacks != nil { for _, callback := range spt.refreshCallbacks { - err := callback(spt.token) + err := callback(spt.inner.Token) if err != nil { return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) } @@ -555,21 +730,33 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { // Refresh obtains a fresh token for the Service Principal. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) Refresh() error { + return spt.RefreshWithContext(context.Background()) +} + +// RefreshWithContext obtains a fresh token for the Service Principal. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(spt.resource) + return spt.refreshInternal(ctx, spt.inner.Resource) } // RefreshExchange refreshes the token, but for a different resource. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { + return spt.RefreshExchangeWithContext(context.Background(), resource) +} + +// RefreshExchangeWithContext refreshes the token, but for a different resource. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(resource) + return spt.refreshInternal(ctx, resource) } func (spt *ServicePrincipalToken) getGrantType() string { - switch spt.secret.(type) { + switch spt.inner.Secret.(type) { case *ServicePrincipalUsernamePasswordSecret: return OAuthGrantTypeUserPass case *ServicePrincipalAuthorizationCodeSecret: @@ -587,23 +774,32 @@ func isIMDS(u url.URL) bool { return u.Host == imds.Host && u.Path == imds.Path } -func (spt *ServicePrincipalToken) refreshInternal(resource string) error { - req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil) +func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { + req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) } - - if !isIMDS(spt.oauthConfig.TokenEndpoint) { + req.Header.Add("User-Agent", version.UserAgent()) + req = req.WithContext(ctx) + if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) { v := url.Values{} - v.Set("client_id", spt.clientID) + v.Set("client_id", spt.inner.ClientID) v.Set("resource", resource) - if spt.token.RefreshToken != "" { + if spt.inner.Token.RefreshToken != "" { v.Set("grant_type", OAuthGrantTypeRefreshToken) - v.Set("refresh_token", spt.token.RefreshToken) + v.Set("refresh_token", spt.inner.Token.RefreshToken) + // web apps must specify client_secret when refreshing tokens + // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens + if spt.getGrantType() == OAuthGrantTypeAuthorizationCode { + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) + if err != nil { + return err + } + } } else { v.Set("grant_type", spt.getGrantType()) - err := spt.secret.SetAuthenticationValues(spt, &v) + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) if err != nil { return err } @@ -616,19 +812,19 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error { req.Body = body } - if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok { + if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok { req.Method = http.MethodGet req.Header.Set(metadataHeader, "true") } var resp *http.Response - if isIMDS(spt.oauthConfig.TokenEndpoint) { - resp, err = retry(spt.sender, req) + if isIMDS(spt.inner.OauthConfig.TokenEndpoint) { + resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts) } else { resp, err = spt.sender.Do(req) } if err != nil { - return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err) + return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil) } defer resp.Body.Close() @@ -636,11 +832,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error { if resp.StatusCode != http.StatusOK { if err != nil { - return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp) + return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp) } return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) } + // for the following error cases don't return a TokenRefreshError. the operation succeeded + // but some transient failure happened during deserialization. by returning a generic error + // the retry logic will kick in (we don't retry on TokenRefreshError). + if err != nil { return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err) } @@ -653,12 +853,14 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error { return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) } - spt.token = token + spt.inner.Token = token return spt.InvokeRefreshCallbacks(token) } -func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { +// retry logic specific to retrieving a token from the IMDS endpoint +func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) { + // copied from client.go due to circular dependency retries := []int{ http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 @@ -667,8 +869,10 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 } - // Extra retry status codes requered - retries = append(retries, http.StatusNotFound, + // extra retry status codes specific to IMDS + retries = append(retries, + http.StatusNotFound, + http.StatusGone, // all remaining 5xx http.StatusNotImplemented, http.StatusHTTPVersionNotSupported, @@ -678,34 +882,52 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { http.StatusNotExtended, http.StatusNetworkAuthenticationRequired) + // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance + + const maxDelay time.Duration = 60 * time.Second + attempt := 0 - maxAttempts := 5 + delay := time.Duration(0) for attempt < maxAttempts { resp, err = sender.Do(req) - if err != nil { + // retry on temporary network errors, e.g. transient network failures. + // if we don't receive a response then assume we can't connect to the + // endpoint so we're likely not running on an Azure VM so don't retry. + if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { return } - if resp.StatusCode == http.StatusOK { - return + // perform exponential backoff with a cap. + // must increment attempt before calculating delay. + attempt++ + // the base value of 2 is the "delta backoff" as specified in the guidance doc + delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) + if delay > maxDelay { + delay = maxDelay } - if containsInt(retries, resp.StatusCode) { - delayed := false - if resp.StatusCode == http.StatusTooManyRequests { - delayed = delay(resp, req.Cancel) - } - if !delayed { - time.Sleep(time.Second) - attempt++ - } - } else { + + select { + case <-time.After(delay): + // intentionally left blank + case <-req.Context().Done(): + err = req.Context().Err() return } } return } +// returns true if the specified error is a temporary network error or false if it's not. +// if the error doesn't implement the net.Error interface the return value is true. +func isTemporaryNetworkError(err error) bool { + if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { + return true + } + return false +} + +// returns true if slice ints contains the value n func containsInt(ints []int, n int) bool { for _, i := range ints { if i == n { @@ -715,31 +937,15 @@ func containsInt(ints []int, n int) bool { return false } -func delay(resp *http.Response, cancel <-chan struct{}) bool { - if resp == nil { - return false - } - retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After")) - if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 { - select { - case <-time.After(time.Duration(retryAfter) * time.Second): - return true - case <-cancel: - return false - } - } - return false -} - // SetAutoRefresh enables or disables automatic refreshing of stale tokens. func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { - spt.autoRefresh = autoRefresh + spt.inner.AutoRefresh = autoRefresh } // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will // refresh the token. func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { - spt.refreshWithin = d + spt.inner.RefreshWithin = d return } @@ -751,12 +957,12 @@ func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } func (spt *ServicePrincipalToken) OAuthToken() string { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token.OAuthToken() + return spt.inner.Token.OAuthToken() } // Token returns a copy of the current token. func (spt *ServicePrincipalToken) Token() Token { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token + return spt.inner.Token } diff --git a/vendor/github.com/Azure/go-autorest/autorest/authorization.go b/vendor/github.com/Azure/go-autorest/autorest/authorization.go index c51eac0a..77eff45b 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/authorization.go +++ b/vendor/github.com/Azure/go-autorest/autorest/authorization.go @@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator { return PreparerFunc(func(r *http.Request) (*http.Request, error) { r, err := p.Prepare(r) if err == nil { - refresher, ok := ba.tokenProvider.(adal.Refresher) - if ok { - err := refresher.EnsureFresh() - if err != nil { - var resp *http.Response - if tokError, ok := err.(adal.TokenRefreshError); ok { - resp = tokError.Response() - } - return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, - "Failed to refresh the Token for request to %s", r.URL) + // the ordering is important here, prefer RefresherWithContext if available + if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok { + err = refresher.EnsureFreshWithContext(r.Context()) + } else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok { + err = refresher.EnsureFresh() + } + if err != nil { + var resp *http.Response + if tokError, ok := err.(adal.TokenRefreshError); ok { + resp = tokError.Response() } + return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, + "Failed to refresh the Token for request to %s", r.URL) } return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))) } diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go index a58e5ef3..cda1e180 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go @@ -21,11 +21,11 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strings" "time" "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/date" ) const ( @@ -44,84 +44,85 @@ var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.Stat // Future provides a mechanism to access the status and results of an asynchronous request. // Since futures are stateful they should be passed by value to avoid race conditions. type Future struct { - req *http.Request - resp *http.Response - ps pollingState + req *http.Request // legacy + pt pollingTracker } // NewFuture returns a new Future object initialized with the specified request. +// Deprecated: Please use NewFutureFromResponse instead. func NewFuture(req *http.Request) Future { return Future{req: req} } -// Response returns the last HTTP response or nil if there isn't one. +// NewFutureFromResponse returns a new Future object initialized +// with the initial response from an asynchronous operation. +func NewFutureFromResponse(resp *http.Response) (Future, error) { + pt, err := createPollingTracker(resp) + if err != nil { + return Future{}, err + } + return Future{pt: pt}, nil +} + +// Response returns the last HTTP response. func (f Future) Response() *http.Response { - return f.resp + if f.pt == nil { + return nil + } + return f.pt.latestResponse() } // Status returns the last status message of the operation. func (f Future) Status() string { - if f.ps.State == "" { - return "Unknown" + if f.pt == nil { + return "" } - return f.ps.State + return f.pt.pollingStatus() } // PollingMethod returns the method used to monitor the status of the asynchronous operation. func (f Future) PollingMethod() PollingMethodType { - return f.ps.PollingMethod + if f.pt == nil { + return PollingUnknown + } + return f.pt.pollingMethod() } // Done queries the service to see if the operation has completed. func (f *Future) Done(sender autorest.Sender) (bool, error) { - // exit early if this future has terminated - if f.ps.hasTerminated() { - return true, f.errorInfo() + // support for legacy Future implementation + if f.req != nil { + resp, err := sender.Do(f.req) + if err != nil { + return false, err + } + pt, err := createPollingTracker(resp) + if err != nil { + return false, err + } + f.pt = pt + f.req = nil } - resp, err := sender.Do(f.req) - f.resp = resp - if err != nil { + // end legacy + if f.pt == nil { + return false, autorest.NewError("Future", "Done", "future is not initialized") + } + if f.pt.hasTerminated() { + return true, f.pt.pollingError() + } + if err := f.pt.pollForStatus(sender); err != nil { return false, err } - - if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { - // check response body for error content - if resp.Body != nil { - type respErr struct { - ServiceError ServiceError `json:"error"` - } - re := respErr{} - - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return false, err - } - err = json.Unmarshal(b, &re) - if err != nil { - return false, err - } - return false, re.ServiceError - } - - // try to return something meaningful - return false, ServiceError{ - Code: fmt.Sprintf("%v", resp.StatusCode), - Message: resp.Status, - } + if err := f.pt.checkForErrors(); err != nil { + return f.pt.hasTerminated(), err } - - err = updatePollingState(resp, &f.ps) - if err != nil { + if err := f.pt.updatePollingState(f.pt.provisioningStateApplicable()); err != nil { return false, err } - - if f.ps.hasTerminated() { - return true, f.errorInfo() + if err := f.pt.updateHeaders(); err != nil { + return false, err } - - f.req, err = newPollingRequest(f.ps) - return false, err + return f.pt.hasTerminated(), f.pt.pollingError() } // GetPollingDelay returns a duration the application should wait before checking @@ -129,11 +130,15 @@ func (f *Future) Done(sender autorest.Sender) (bool, error) { // the service via the Retry-After response header. If the header wasn't returned // then the function returns the zero-value time.Duration and false. func (f Future) GetPollingDelay() (time.Duration, bool) { - if f.resp == nil { + if f.pt == nil { + return 0, false + } + resp := f.pt.latestResponse() + if resp == nil { return 0, false } - retry := f.resp.Header.Get(autorest.HeaderRetryAfter) + retry := resp.Header.Get(autorest.HeaderRetryAfter) if retry == "" { return 0, false } @@ -150,14 +155,22 @@ func (f Future) GetPollingDelay() (time.Duration, bool) { // running operation has completed, the provided context is cancelled, or the client's // polling duration has been exceeded. It will retry failed polling attempts based on // the retry value defined in the client up to the maximum retry attempts. +// Deprecated: Please use WaitForCompletionRef() instead. func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) error { + return f.WaitForCompletionRef(ctx, client) +} + +// WaitForCompletionRef will return when one of the following conditions is met: the long +// running operation has completed, the provided context is cancelled, or the client's +// polling duration has been exceeded. It will retry failed polling attempts based on +// the retry value defined in the client up to the maximum retry attempts. +func (f *Future) WaitForCompletionRef(ctx context.Context, client autorest.Client) error { ctx, cancel := context.WithTimeout(ctx, client.PollingDuration) defer cancel() - done, err := f.Done(client) for attempts := 0; !done; done, err = f.Done(client) { if attempts >= client.RetryAttempts { - return autorest.NewErrorWithError(err, "azure", "WaitForCompletion", f.resp, "the number of retries has been exceeded") + return autorest.NewErrorWithError(err, "Future", "WaitForCompletion", f.pt.latestResponse(), "the number of retries has been exceeded") } // we want delayAttempt to be zero in the non-error case so // that DelayForBackoff doesn't perform exponential back-off @@ -181,162 +194,692 @@ func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) e // wait until the delay elapses or the context is cancelled delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, ctx.Done()) if !delayElapsed { - return autorest.NewErrorWithError(ctx.Err(), "azure", "WaitForCompletion", f.resp, "context has been cancelled") + return autorest.NewErrorWithError(ctx.Err(), "Future", "WaitForCompletion", f.pt.latestResponse(), "context has been cancelled") } } return err } -// if the operation failed the polling state will contain -// error information and implements the error interface -func (f *Future) errorInfo() error { - if !f.ps.hasSucceeded() { - return f.ps - } - return nil -} - // MarshalJSON implements the json.Marshaler interface. func (f Future) MarshalJSON() ([]byte, error) { - return json.Marshal(&f.ps) + return json.Marshal(f.pt) } // UnmarshalJSON implements the json.Unmarshaler interface. func (f *Future) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &f.ps) + // unmarshal into JSON object to determine the tracker type + obj := map[string]interface{}{} + err := json.Unmarshal(data, &obj) if err != nil { return err } - f.req, err = newPollingRequest(f.ps) - return err + if obj["method"] == nil { + return autorest.NewError("Future", "UnmarshalJSON", "missing 'method' property") + } + method := obj["method"].(string) + switch strings.ToUpper(method) { + case http.MethodDelete: + f.pt = &pollingTrackerDelete{} + case http.MethodPatch: + f.pt = &pollingTrackerPatch{} + case http.MethodPost: + f.pt = &pollingTrackerPost{} + case http.MethodPut: + f.pt = &pollingTrackerPut{} + default: + return autorest.NewError("Future", "UnmarshalJSON", "unsupoorted method '%s'", method) + } + // now unmarshal into the tracker + return json.Unmarshal(data, &f.pt) } // PollingURL returns the URL used for retrieving the status of the long-running operation. -// For LROs that use the Location header the final URL value is used to retrieve the result. func (f Future) PollingURL() string { - return f.ps.URI + if f.pt == nil { + return "" + } + return f.pt.pollingURL() +} + +// GetResult should be called once polling has completed successfully. +// It makes the final GET call to retrieve the resultant payload. +func (f Future) GetResult(sender autorest.Sender) (*http.Response, error) { + if f.pt.finalGetURL() == "" { + // we can end up in this situation if the async operation returns a 200 + // with no polling URLs. in that case return the response which should + // contain the JSON payload (only do this for successful terminal cases). + if lr := f.pt.latestResponse(); lr != nil && f.pt.hasSucceeded() { + return lr, nil + } + return nil, autorest.NewError("Future", "GetResult", "missing URL for retrieving result") + } + req, err := http.NewRequest(http.MethodGet, f.pt.finalGetURL(), nil) + if err != nil { + return nil, err + } + return sender.Do(req) +} + +type pollingTracker interface { + // these methods can differ per tracker + + // checks the response headers and status code to determine the polling mechanism + updateHeaders() error + + // checks the response for tracker-specific error conditions + checkForErrors() error + + // returns true if provisioning state should be checked + provisioningStateApplicable() bool + + // methods common to all trackers + + // initializes the tracker's internal state, call this when the tracker is created + initializeState() error + + // makes an HTTP request to check the status of the LRO + pollForStatus(sender autorest.Sender) error + + // updates internal tracker state, call this after each call to pollForStatus + updatePollingState(provStateApl bool) error + + // returns the error response from the service, can be nil + pollingError() error + + // returns the polling method being used + pollingMethod() PollingMethodType + + // returns the state of the LRO as returned from the service + pollingStatus() string + + // returns the URL used for polling status + pollingURL() string + + // returns the URL used for the final GET to retrieve the resource + finalGetURL() string + + // returns true if the LRO is in a terminal state + hasTerminated() bool + + // returns true if the LRO is in a failed terminal state + hasFailed() bool + + // returns true if the LRO is in a successful terminal state + hasSucceeded() bool + + // returns the cached HTTP response after a call to pollForStatus(), can be nil + latestResponse() *http.Response +} + +type pollingTrackerBase struct { + // resp is the last response, either from the submission of the LRO or from polling + resp *http.Response + + // method is the HTTP verb, this is needed for deserialization + Method string `json:"method"` + + // rawBody is the raw JSON response body + rawBody map[string]interface{} + + // denotes if polling is using async-operation or location header + Pm PollingMethodType `json:"pollingMethod"` + + // the URL to poll for status + URI string `json:"pollingURI"` + + // the state of the LRO as returned from the service + State string `json:"lroState"` + + // the URL to GET for the final result + FinalGetURI string `json:"resultURI"` + + // used to hold an error object returned from the service + Err *ServiceError `json:"error,omitempty"` +} + +func (pt *pollingTrackerBase) initializeState() error { + // determine the initial polling state based on response body and/or HTTP status + // code. this is applicable to the initial LRO response, not polling responses! + pt.Method = pt.resp.Request.Method + if err := pt.updateRawBody(); err != nil { + return err + } + switch pt.resp.StatusCode { + case http.StatusOK: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationSucceeded + } + case http.StatusCreated: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationInProgress + } + case http.StatusAccepted: + pt.State = operationInProgress + case http.StatusNoContent: + pt.State = operationSucceeded + default: + pt.State = operationFailed + pt.updateErrorFromResponse() + } + return nil +} + +func (pt pollingTrackerBase) getProvisioningState() *string { + if pt.rawBody != nil && pt.rawBody["properties"] != nil { + p := pt.rawBody["properties"].(map[string]interface{}) + if ps := p["provisioningState"]; ps != nil { + s := ps.(string) + return &s + } + } + return nil +} + +func (pt *pollingTrackerBase) updateRawBody() error { + pt.rawBody = map[string]interface{}{} + if pt.resp.ContentLength != 0 { + defer pt.resp.Body.Close() + b, err := ioutil.ReadAll(pt.resp.Body) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to read response body") + } + // put the body back so it's available to other callers + pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + if err = json.Unmarshal(b, &pt.rawBody); err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to unmarshal response body") + } + } + return nil +} + +func (pt *pollingTrackerBase) pollForStatus(sender autorest.Sender) error { + req, err := http.NewRequest(http.MethodGet, pt.URI, nil) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to create HTTP request") + } + // attach the context from the original request if available (it will be absent for deserialized futures) + if pt.resp != nil { + req = req.WithContext(pt.resp.Request.Context()) + } + pt.resp, err = sender.Do(req) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to send HTTP request") + } + if autorest.ResponseHasStatusCode(pt.resp, pollingCodes[:]...) { + // reset the service error on success case + pt.Err = nil + err = pt.updateRawBody() + } else { + // check response body for error content + pt.updateErrorFromResponse() + } + return err +} + +// attempts to unmarshal a ServiceError type from the response body. +// if that fails then make a best attempt at creating something meaningful. +func (pt *pollingTrackerBase) updateErrorFromResponse() { + var err error + if pt.resp.ContentLength != 0 { + type respErr struct { + ServiceError *ServiceError `json:"error"` + } + re := respErr{} + defer pt.resp.Body.Close() + var b []byte + b, err = ioutil.ReadAll(pt.resp.Body) + if err != nil { + goto Default + } + if err = json.Unmarshal(b, &re); err != nil { + goto Default + } + // unmarshalling the error didn't yield anything, try unwrapped error + if re.ServiceError == nil { + err = json.Unmarshal(b, &re.ServiceError) + if err != nil { + goto Default + } + } + if re.ServiceError != nil { + pt.Err = re.ServiceError + return + } + } +Default: + se := &ServiceError{ + Code: fmt.Sprintf("HTTP status code %v", pt.resp.StatusCode), + Message: pt.resp.Status, + } + if err != nil { + se.InnerError = make(map[string]interface{}) + se.InnerError["unmarshalError"] = err.Error() + } + pt.Err = se +} + +func (pt *pollingTrackerBase) updatePollingState(provStateApl bool) error { + if pt.Pm == PollingAsyncOperation && pt.rawBody["status"] != nil { + pt.State = pt.rawBody["status"].(string) + } else { + if pt.resp.StatusCode == http.StatusAccepted { + pt.State = operationInProgress + } else if provStateApl { + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationSucceeded + } + } else { + return autorest.NewError("pollingTrackerBase", "updatePollingState", "the response from the async operation has an invalid status code") + } + } + // if the operation has failed update the error state + if pt.hasFailed() { + pt.updateErrorFromResponse() + } + return nil +} + +func (pt pollingTrackerBase) pollingError() error { + if pt.Err == nil { + return nil + } + return pt.Err +} + +func (pt pollingTrackerBase) pollingMethod() PollingMethodType { + return pt.Pm +} + +func (pt pollingTrackerBase) pollingStatus() string { + return pt.State +} + +func (pt pollingTrackerBase) pollingURL() string { + return pt.URI +} + +func (pt pollingTrackerBase) finalGetURL() string { + return pt.FinalGetURI +} + +func (pt pollingTrackerBase) hasTerminated() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) || strings.EqualFold(pt.State, operationSucceeded) +} + +func (pt pollingTrackerBase) hasFailed() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) +} + +func (pt pollingTrackerBase) hasSucceeded() bool { + return strings.EqualFold(pt.State, operationSucceeded) +} + +func (pt pollingTrackerBase) latestResponse() *http.Response { + return pt.resp +} + +// error checking common to all trackers +func (pt pollingTrackerBase) baseCheckForErrors() error { + // for Azure-AsyncOperations the response body cannot be nil or empty + if pt.Pm == PollingAsyncOperation { + if pt.resp.Body == nil || pt.resp.ContentLength == 0 { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "for Azure-AsyncOperation response body cannot be nil") + } + if pt.rawBody["status"] == nil { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "missing status property in Azure-AsyncOperation response body") + } + } + return nil +} + +// DELETE + +type pollingTrackerDelete struct { + pollingTrackerBase +} + +func (pt *pollingTrackerDelete) updateHeaders() error { + // for 201 the Location header is required + if pt.resp.StatusCode == http.StatusCreated { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerDelete", "updateHeaders", "missing Location header in 201 response") + } else { + pt.URI = lh + } + pt.Pm = PollingLocation + pt.FinalGetURI = pt.URI + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation + } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } + } + return nil +} + +func (pt pollingTrackerDelete) checkForErrors() error { + return pt.baseCheckForErrors() +} + +func (pt pollingTrackerDelete) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent +} + +// PATCH + +type pollingTrackerPatch struct { + pollingTrackerBase +} + +func (pt *pollingTrackerPatch) updateHeaders() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() + } + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() + } + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + // note the absense of the "final GET" mechanism for PATCH + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + if ao == "" { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPatch", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } else { + pt.URI = lh + pt.Pm = PollingLocation + } + } + } + return nil +} + +func (pt pollingTrackerPatch) checkForErrors() error { + return pt.baseCheckForErrors() +} + +func (pt pollingTrackerPatch) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated +} + +// POST + +type pollingTrackerPost struct { + pollingTrackerBase +} + +func (pt *pollingTrackerPost) updateHeaders() error { + // 201 requires Location header + if pt.resp.StatusCode == http.StatusCreated { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "missing Location header in 201 response") + } else { + pt.URI = lh + pt.FinalGetURI = lh + pt.Pm = PollingLocation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation + } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } + } + return nil +} + +func (pt pollingTrackerPost) checkForErrors() error { + return pt.baseCheckForErrors() +} + +func (pt pollingTrackerPost) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent +} + +// PUT + +type pollingTrackerPut struct { + pollingTrackerBase +} + +func (pt *pollingTrackerPut) updateHeaders() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() + } + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() + } + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation + } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPut", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } + } + return nil +} + +func (pt pollingTrackerPut) checkForErrors() error { + err := pt.baseCheckForErrors() + if err != nil { + return err + } + // if there are no LRO headers then the body cannot be empty + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } + lh, err := getURLFromLocationHeader(pt.resp) + if err != nil { + return err + } + if ao == "" && lh == "" && len(pt.rawBody) == 0 { + return autorest.NewError("pollingTrackerPut", "checkForErrors", "the response did not contain a body") + } + return nil +} + +func (pt pollingTrackerPut) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated +} + +// creates a polling tracker based on the verb of the original request +func createPollingTracker(resp *http.Response) (pollingTracker, error) { + var pt pollingTracker + switch strings.ToUpper(resp.Request.Method) { + case http.MethodDelete: + pt = &pollingTrackerDelete{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPatch: + pt = &pollingTrackerPatch{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPost: + pt = &pollingTrackerPost{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPut: + pt = &pollingTrackerPut{pollingTrackerBase: pollingTrackerBase{resp: resp}} + default: + return nil, autorest.NewError("azure", "createPollingTracker", "unsupported HTTP method %s", resp.Request.Method) + } + if err := pt.initializeState(); err != nil { + return pt, err + } + // this initializes the polling header values, we do this during creation in case the + // initial response send us invalid values; this way the API call will return a non-nil + // error (not doing this means the error shows up in Future.Done) + return pt, pt.updateHeaders() +} + +// gets the polling URL from the Azure-AsyncOperation header. +// ensures the URL is well-formed and absolute. +func getURLFromAsyncOpHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromAsyncOpHeader", "invalid polling URL '%s'", s) + } + return s, nil +} + +// gets the polling URL from the Location header. +// ensures the URL is well-formed and absolute. +func getURLFromLocationHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(autorest.HeaderLocation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromLocationHeader", "invalid polling URL '%s'", s) + } + return s, nil +} + +// verify that the URL is valid and absolute +func isValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() } // DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure // long-running operation. It will delay between requests for the duration specified in the -// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by -// closing the optional channel on the http.Request. +// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled via +// the context associated with the http.Request. +// Deprecated: Prefer using Futures to allow for non-blocking async operations. func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator { return func(s autorest.Sender) autorest.Sender { - return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) { - resp, err = s.Do(r) + return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { + resp, err := s.Do(r) if err != nil { return resp, err } if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { return resp, nil } - - ps := pollingState{} - for err == nil { - err = updatePollingState(resp, &ps) - if err != nil { - break - } - if ps.hasTerminated() { - if !ps.hasSucceeded() { - err = ps - } - break - } - - r, err = newPollingRequest(ps) - if err != nil { - return resp, err - } - r = r.WithContext(resp.Request.Context()) - - delay = autorest.GetRetryAfter(resp, delay) - resp, err = autorest.SendWithSender(s, r, - autorest.AfterDelay(delay)) + future, err := NewFutureFromResponse(resp) + if err != nil { + return resp, err } - - return resp, err + // retry until either the LRO completes or we receive an error + var done bool + for done, err = future.Done(s); !done && err == nil; done, err = future.Done(s) { + // check for Retry-After delay, if not present use the specified polling delay + if pd, ok := future.GetPollingDelay(); ok { + delay = pd + } + // wait until the delay elapses or the context is cancelled + if delayElapsed := autorest.DelayForBackoff(delay, 0, r.Context().Done()); !delayElapsed { + return future.Response(), + autorest.NewErrorWithError(r.Context().Err(), "azure", "DoPollForAsynchronous", future.Response(), "context has been cancelled") + } + } + return future.Response(), err }) } } -func getAsyncOperation(resp *http.Response) string { - return resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) -} - -func hasSucceeded(state string) bool { - return strings.EqualFold(state, operationSucceeded) -} - -func hasTerminated(state string) bool { - return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded) -} - -func hasFailed(state string) bool { - return strings.EqualFold(state, operationFailed) -} - -type provisioningTracker interface { - state() string - hasSucceeded() bool - hasTerminated() bool -} - -type operationResource struct { - // Note: - // The specification states services should return the "id" field. However some return it as - // "operationId". - ID string `json:"id"` - OperationID string `json:"operationId"` - Name string `json:"name"` - Status string `json:"status"` - Properties map[string]interface{} `json:"properties"` - OperationError ServiceError `json:"error"` - StartTime date.Time `json:"startTime"` - EndTime date.Time `json:"endTime"` - PercentComplete float64 `json:"percentComplete"` -} - -func (or operationResource) state() string { - return or.Status -} - -func (or operationResource) hasSucceeded() bool { - return hasSucceeded(or.state()) -} - -func (or operationResource) hasTerminated() bool { - return hasTerminated(or.state()) -} - -type provisioningProperties struct { - ProvisioningState string `json:"provisioningState"` -} - -type provisioningStatus struct { - Properties provisioningProperties `json:"properties,omitempty"` - ProvisioningError ServiceError `json:"error,omitempty"` -} - -func (ps provisioningStatus) state() string { - return ps.Properties.ProvisioningState -} - -func (ps provisioningStatus) hasSucceeded() bool { - return hasSucceeded(ps.state()) -} - -func (ps provisioningStatus) hasTerminated() bool { - return hasTerminated(ps.state()) -} - -func (ps provisioningStatus) hasProvisioningError() bool { - // code and message are required fields so only check them - return len(ps.ProvisioningError.Code) > 0 || - len(ps.ProvisioningError.Message) > 0 -} - // PollingMethodType defines a type used for enumerating polling mechanisms. type PollingMethodType string @@ -347,151 +890,13 @@ const ( // PollingLocation indicates the polling method uses the Location header. PollingLocation PollingMethodType = "Location" + // PollingRequestURI indicates the polling method uses the original request URI. + PollingRequestURI PollingMethodType = "RequestURI" + // PollingUnknown indicates an unknown polling method and is the default value. PollingUnknown PollingMethodType = "" ) -type pollingState struct { - PollingMethod PollingMethodType `json:"pollingMethod"` - URI string `json:"uri"` - State string `json:"state"` - ServiceError *ServiceError `json:"error,omitempty"` -} - -func (ps pollingState) hasSucceeded() bool { - return hasSucceeded(ps.State) -} - -func (ps pollingState) hasTerminated() bool { - return hasTerminated(ps.State) -} - -func (ps pollingState) hasFailed() bool { - return hasFailed(ps.State) -} - -func (ps pollingState) Error() string { - s := fmt.Sprintf("Long running operation terminated with status '%s'", ps.State) - if ps.ServiceError != nil { - s = fmt.Sprintf("%s: %+v", s, *ps.ServiceError) - } - return s -} - -// updatePollingState maps the operation status -- retrieved from either a provisioningState -// field, the status field of an OperationResource, or inferred from the HTTP status code -- -// into a well-known states. Since the process begins from the initial request, the state -// always comes from either a the provisioningState returned or is inferred from the HTTP -// status code. Subsequent requests will read an Azure OperationResource object if the -// service initially returned the Azure-AsyncOperation header. The responseFormat field notes -// the expected response format. -func updatePollingState(resp *http.Response, ps *pollingState) error { - // Determine the response shape - // -- The first response will always be a provisioningStatus response; only the polling requests, - // depending on the header returned, may be something otherwise. - var pt provisioningTracker - if ps.PollingMethod == PollingAsyncOperation { - pt = &operationResource{} - } else { - pt = &provisioningStatus{} - } - - // If this is the first request (that is, the polling response shape is unknown), determine how - // to poll and what to expect - if ps.PollingMethod == PollingUnknown { - req := resp.Request - if req == nil { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing") - } - - // Prefer the Azure-AsyncOperation header - ps.URI = getAsyncOperation(resp) - if ps.URI != "" { - ps.PollingMethod = PollingAsyncOperation - } else { - ps.PollingMethod = PollingLocation - } - - // Else, use the Location header - if ps.URI == "" { - ps.URI = autorest.GetLocation(resp) - } - - // Lastly, requests against an existing resource, use the last request URI - if ps.URI == "" { - m := strings.ToUpper(req.Method) - if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet { - ps.URI = req.URL.String() - } - } - } - - // Read and interpret the response (saving the Body in case no polling is necessary) - b := &bytes.Buffer{} - err := autorest.Respond(resp, - autorest.ByCopying(b), - autorest.ByUnmarshallingJSON(pt), - autorest.ByClosing()) - resp.Body = ioutil.NopCloser(b) - if err != nil { - return err - } - - // Interpret the results - // -- Terminal states apply regardless - // -- Unknown states are per-service inprogress states - // -- Otherwise, infer state from HTTP status code - if pt.hasTerminated() { - ps.State = pt.state() - } else if pt.state() != "" { - ps.State = operationInProgress - } else { - switch resp.StatusCode { - case http.StatusAccepted: - ps.State = operationInProgress - - case http.StatusNoContent, http.StatusCreated, http.StatusOK: - ps.State = operationSucceeded - - default: - ps.State = operationFailed - } - } - - if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL) - } - - // For failed operation, check for error code and message in - // -- Operation resource - // -- Response - // -- Otherwise, Unknown - if ps.hasFailed() { - if or, ok := pt.(*operationResource); ok { - ps.ServiceError = &or.OperationError - } else if p, ok := pt.(*provisioningStatus); ok && p.hasProvisioningError() { - ps.ServiceError = &p.ProvisioningError - } else { - ps.ServiceError = &ServiceError{ - Code: "Unknown", - Message: "None", - } - } - } - return nil -} - -func newPollingRequest(ps pollingState) (*http.Request, error) { - reqPoll, err := autorest.Prepare(&http.Request{}, - autorest.AsGet(), - autorest.WithBaseURL(ps.URI)) - if err != nil { - return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.URI) - } - - return reqPoll, nil -} - // AsyncOpIncompleteError is the type that's returned from a future that has not completed. type AsyncOpIncompleteError struct { // FutureType is the name of the type composed of a azure.Future. diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go index 18d02952..3a0a439f 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go @@ -44,11 +44,12 @@ const ( // ServiceError encapsulates the error response from an Azure service. // It adhears to the OData v4 specification for error responses. type ServiceError struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details []map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + Code string `json:"code"` + Message string `json:"message"` + Target *string `json:"target"` + Details []map[string]interface{} `json:"details"` + InnerError map[string]interface{} `json:"innererror"` + AdditionalInfo []map[string]interface{} `json:"additionalInfo"` } func (se ServiceError) Error() string { @@ -74,6 +75,14 @@ func (se ServiceError) Error() string { result += fmt.Sprintf(" InnerError=%v", string(d)) } + if se.AdditionalInfo != nil { + d, err := json.Marshal(se.AdditionalInfo) + if err != nil { + result += fmt.Sprintf(" AdditionalInfo=%v", se.AdditionalInfo) + } + result += fmt.Sprintf(" AdditionalInfo=%v", string(d)) + } + return result } @@ -86,44 +95,47 @@ func (se *ServiceError) UnmarshalJSON(b []byte) error { // http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091 type serviceError1 struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details []map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + Code string `json:"code"` + Message string `json:"message"` + Target *string `json:"target"` + Details []map[string]interface{} `json:"details"` + InnerError map[string]interface{} `json:"innererror"` + AdditionalInfo []map[string]interface{} `json:"additionalInfo"` } type serviceError2 struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + Code string `json:"code"` + Message string `json:"message"` + Target *string `json:"target"` + Details map[string]interface{} `json:"details"` + InnerError map[string]interface{} `json:"innererror"` + AdditionalInfo []map[string]interface{} `json:"additionalInfo"` } se1 := serviceError1{} err := json.Unmarshal(b, &se1) if err == nil { - se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError) + se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError, se1.AdditionalInfo) return nil } se2 := serviceError2{} err = json.Unmarshal(b, &se2) if err == nil { - se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError) + se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError, se2.AdditionalInfo) se.Details = append(se.Details, se2.Details) return nil } return err } -func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}) { +func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}, additional []map[string]interface{}) { se.Code = code se.Message = message se.Target = target se.Details = details se.InnerError = inner + se.AdditionalInfo = additional } // RequestError describes an error response returned by Azure service. @@ -279,16 +291,29 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator { resp.Body = ioutil.NopCloser(&b) if decodeErr != nil { return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr) - } else if e.ServiceError == nil { + } + if e.ServiceError == nil { // Check if error is unwrapped ServiceError - if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" { - e.ServiceError = &ServiceError{ - Code: "Unknown", - Message: "Unknown service error", - } + if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil { + return err } } - + if e.ServiceError.Message == "" { + // if we're here it means the returned error wasn't OData v4 compliant. + // try to unmarshal the body as raw JSON in hopes of getting something. + rawBody := map[string]interface{}{} + if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil { + return err + } + e.ServiceError = &ServiceError{ + Code: "Unknown", + Message: "Unknown service error", + } + if len(rawBody) > 0 { + e.ServiceError.Details = []map[string]interface{}{rawBody} + } + } + e.Response = resp e.RequestID = ExtractRequestID(resp) if e.StatusCode == nil { e.StatusCode = resp.StatusCode diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go index 65ad0afc..bd34f0ed 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go @@ -64,7 +64,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { } } } - return resp, fmt.Errorf("failed request: %s", err) + return resp, err }) } } diff --git a/vendor/github.com/Azure/go-autorest/autorest/client.go b/vendor/github.com/Azure/go-autorest/autorest/client.go index 4e92dcad..467c1509 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/client.go +++ b/vendor/github.com/Azure/go-autorest/autorest/client.go @@ -22,8 +22,9 @@ import ( "log" "net/http" "net/http/cookiejar" - "runtime" "time" + + "github.com/Azure/go-autorest/version" ) const ( @@ -41,15 +42,6 @@ const ( ) var ( - // defaultUserAgent builds a string containing the Go version, system archityecture and OS, - // and the go-autorest version. - defaultUserAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s", - runtime.Version(), - runtime.GOARCH, - runtime.GOOS, - Version(), - ) - // StatusCodesForRetry are a defined group of status code for which the client will retry StatusCodesForRetry = []int{ http.StatusRequestTimeout, // 408 @@ -179,7 +171,7 @@ func NewClientWithUserAgent(ua string) Client { PollingDuration: DefaultPollingDuration, RetryAttempts: DefaultRetryAttempts, RetryDuration: DefaultRetryDuration, - UserAgent: defaultUserAgent, + UserAgent: version.UserAgent(), } c.Sender = c.sender() c.AddToUserAgent(ua) diff --git a/vendor/github.com/Azure/go-autorest/autorest/sender.go b/vendor/github.com/Azure/go-autorest/autorest/sender.go index b4f76232..cacbd815 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/sender.go +++ b/vendor/github.com/Azure/go-autorest/autorest/sender.go @@ -223,6 +223,10 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se return resp, err } resp, err = s.Do(rr.Request()) + // if the error isn't temporary don't bother retrying + if err != nil && !IsTemporaryNetworkError(err) { + return nil, err + } // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication // resp and err will both have a value, so in this case we don't want to retry as it will never succeed. if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) { diff --git a/vendor/github.com/Azure/go-autorest/autorest/utility.go b/vendor/github.com/Azure/go-autorest/autorest/utility.go index afb3e4e1..bfddd90b 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/utility.go +++ b/vendor/github.com/Azure/go-autorest/autorest/utility.go @@ -20,6 +20,7 @@ import ( "encoding/xml" "fmt" "io" + "net" "net/http" "net/url" "reflect" @@ -216,3 +217,12 @@ func IsTokenRefreshError(err error) bool { } return false } + +// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false +// if it's not. If the error doesn't implement the net.Error interface the return value is true. +func IsTemporaryNetworkError(err error) bool { + if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { + return true + } + return false +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/version.go b/vendor/github.com/Azure/go-autorest/autorest/version.go index 4ad7754a..3c645154 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/version.go +++ b/vendor/github.com/Azure/go-autorest/autorest/version.go @@ -1,5 +1,7 @@ package autorest +import "github.com/Azure/go-autorest/version" + // Copyright 2017 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,5 +18,5 @@ package autorest // Version returns the semantic version (see http://semver.org). func Version() string { - return "v10.5.0" + return version.Number } diff --git a/vendor/github.com/Azure/go-autorest/version/version.go b/vendor/github.com/Azure/go-autorest/version/version.go new file mode 100644 index 00000000..10d1df64 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/version/version.go @@ -0,0 +1,37 @@ +package version + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "fmt" + "runtime" +) + +// Number contains the semantic version of this SDK. +const Number = "v10.14.0" + +var ( + userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s", + runtime.Version(), + runtime.GOARCH, + runtime.GOOS, + Number, + ) +) + +// UserAgent returns a string containing the Go version, system archityecture and OS, and the go-autorest version. +func UserAgent() string { + return userAgent +}