Azure driver retry fix (#4576)

This commit is contained in:
Milos Gajdos
2025-03-14 10:20:25 -07:00
committed by GitHub
15 changed files with 636 additions and 156 deletions

View File

@@ -37,7 +37,8 @@ jobs:
- 1.23.6
target:
- test-coverage
- test-cloud-storage
- test-cloud-storage # TODO: rename to test-s3-storage
- test-azure-storage
steps:
-
name: Checkout

View File

@@ -161,6 +161,27 @@ start-e2e-s3-env: ## starts E2E S3 storage test environment (S3, Redis, registry
stop-e2e-s3-env: ## stops E2E S3 storage test environment (S3, Redis, registry)
$(COMPOSE) -f tests/docker-compose-e2e-cloud-storage.yml down
.PHONY: test-azure-storage
test-azure-storage: start-azure-storage run-azure-tests stop-azure-storage ## run Azure storage driver tests
.PHONY: start-azure-storage
start-azure-storage: ## start local Azure storage (Azurite)
$(COMPOSE) -f tests/docker-compose-azure-blob-store.yaml up azurite azurite-init -d
.PHONY: stop-azure-storage
stop-azure-storage: ## stop local Azure storage (Azurite)
$(COMPOSE) -f tests/docker-compose-azure-blob-store.yaml down
.PHONY: run-azure-tests
run-azure-tests: start-azure-storage ## run Azure storage driver integration tests
AZURE_SKIP_VERIFY=true \
AZURE_STORAGE_CREDENTIALS_TYPE="shared_key" \
AZURE_STORAGE_ACCOUNT_NAME=devstoreaccount1 \
AZURE_STORAGE_ACCOUNT_KEY="Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" \
AZURE_STORAGE_CONTAINER=containername \
AZURE_SERVICE_URL="https://127.0.0.1:10000/devstoreaccount1" \
go test ${TESTFLAGS} -count=1 ./registry/storage/driver/azure/...
##@ Validate
lint: ## run all linters

View File

@@ -120,8 +120,8 @@ storage:
clientid: client_id_string
tenantid: tenant_id_string
secret: secret_string
copy_status_poll_max_retry: 10
copy_status_poll_delay: 100ms
max_retries: 10
retry_delay: 100ms
gcs:
bucket: bucketname
keyfile: /path/to/keyfile

View File

@@ -13,11 +13,26 @@ An implementation of the `storagedriver.StorageDriver` interface which uses [Mic
| `accountname` | yes | Name of the Azure Storage Account. |
| `accountkey` | yes | Primary or Secondary Key for the Storage Account. |
| `container` | yes | Name of the Azure root storage container in which all registry data is stored. Must comply the storage container name [requirements](https://docs.microsoft.com/rest/api/storageservices/fileservices/naming-and-referencing-containers--blobs--and-metadata). For example, if your url is `https://myaccount.blob.core.windows.net/myblob` use the container value of `myblob`.|
| `credentials` | yes | Azure credentials used to authenticate with Azure blob storage. |
| `rootdirectory` | no | This is a prefix that is applied to all Azure keys to allow you to segment data in your container if necessary. |
| `realm` | no | Domain name suffix for the Storage Service API endpoint. For example realm for "Azure in China" would be `core.chinacloudapi.cn` and realm for "Azure Government" would be `core.usgovcloudapi.net`. By default, this is `core.windows.net`. |
| `copy_status_poll_max_retry` | no | Max retry number for polling of copy operation status. Retries use a simple backoff algorithm where each retry number is multiplied by `copy_status_poll_delay`, and this number is used as the delay. Set to -1 to disable retries and abort if the copy does not complete immediately. Defaults to 5. |
| `copy_status_poll_delay` | no | Time to wait between retries for polling of copy operation status. This time is multiplied by N on each retry, where N is the retry number. Defaults to 100ms |
| `max_retries` | no | Maximum number of retries for driver operations. Retries use a simple backoff algorithm where each retry number is multiplied by `retry_delay`, and this number is used as the delay. Set to -1 to disable retries and abort an operation (such as a pending copy) if it does not complete immediately. Defaults to 5. |
| `retry_delay` | no | Time to wait between retries for driver operation status. This time is multiplied by N on each retry, where N is the retry number. Defaults to 100ms |
### Credentials
| Parameter | Required | Description |
|:-----------------------------------|:---------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | yes | Azure credentials used to authenticate with Azure blob storage (`client_secret`, `shared_key`, `default_credentials`). |
| `clientid` | yes | The unique application ID of this application in your directory. |
| `tenantid` | yes | Azure Active Directory's globally unique identifier. |
| `secret` | yes | A secret string that the application uses to prove its identity when requesting a token. |
* `client_secret`: [used for token authentication](https://learn.microsoft.com/en-us/azure/developer/go/sdk/authentication/authentication-overview#advantages-of-token-based-authentication)
* `shared_key`: used for shared key credentials authentication (read more [here](https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key))
* `default_credentials`: [default Azure credential authentication](https://learn.microsoft.com/en-us/azure/developer/go/sdk/authentication/authentication-overview#defaultazurecredential)
## Related information
* For more information about Azure Blob Storage, see [the official docs](https://azure.microsoft.com/en-us/services/storage/).

View File

@@ -2,10 +2,14 @@ package azure
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
@@ -45,49 +49,111 @@ type azureClient struct {
signer signer
}
func newAzureClient(params *Parameters) (*azureClient, error) {
if params.AccountKey != "" {
cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey)
if err != nil {
return nil, err
}
client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, nil)
if err != nil {
return nil, err
}
signer := &sharedKeySigner{
cred: cred,
}
return &azureClient{
container: params.Container,
client: client,
signer: signer,
}, nil
// newClient constructs an azureClient appropriate for the credentials type
// configured in params: client_secret credentials produce a token-based
// client, while shared_key and default credentials produce a shared-key
// client. Any other credentials type is rejected with an error.
//
// NOTE(review): routing CredentialsTypeDefault to
// newSharedKeyCredentialsClient looks suspicious — newTokenClient has a
// DefaultAzureCredential fallback branch that this routing makes
// unreachable, and a shared-key client requires AccountKey. Confirm the
// intended mapping against the upstream driver.
func newClient(params *DriverParameters) (*azureClient, error) {
	switch params.Credentials.Type {
	case CredentialsTypeClientSecret:
		return newTokenClient(params)
	case CredentialsTypeSharedKey, CredentialsTypeDefault:
		return newSharedKeyCredentialsClient(params)
	}

	return nil, fmt.Errorf("invalid credentials type: %q", params.Credentials.Type)
}
var cred azcore.TokenCredential
var err error
if params.Credentials.Type == "client_secret" {
func newTokenClient(params *DriverParameters) (*azureClient, error) {
var (
cred azcore.TokenCredential
err error
)
switch params.Credentials.Type {
case CredentialsTypeClientSecret:
creds := &params.Credentials
if cred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil); err != nil {
return nil, err
cred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil)
if err != nil {
return nil, fmt.Errorf("client secret credentials: %v", err)
}
default:
cred, err = azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, fmt.Errorf("default credentials: %v", err)
}
} else if cred, err = azidentity.NewDefaultAzureCredential(nil); err != nil {
return nil, err
}
client, err := azblob.NewClient(params.ServiceURL, cred, nil)
azBlobOpts := &azblob.ClientOptions{
ClientOptions: azcore.ClientOptions{
PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()},
Logging: policy.LogOptions{
AllowedHeaders: []string{
"x-ms-error-code",
"Retry-After",
"Retry-After-Ms",
"If-Match",
"x-ms-blob-condition-appendpos",
},
AllowedQueryParams: []string{"comp"},
},
},
}
if params.SkipVerify {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
azBlobOpts.Transport = &http.Client{
Transport: httpTransport,
}
}
client, err := azblob.NewClient(params.ServiceURL, cred, azBlobOpts)
if err != nil {
return nil, err
}
signer := &clientTokenSigner{
client: client,
cred: cred,
return nil, fmt.Errorf("new azure token client: %v", err)
}
return &azureClient{
container: params.Container,
client: client,
signer: signer,
signer: &clientTokenSigner{
client: client,
cred: cred,
},
}, nil
}
// newSharedKeyCredentialsClient builds an azureClient that authenticates
// with the storage-account shared key (params.AccountName/AccountKey).
// The client pipeline is configured with the custom retry-notification
// policy so the driver can detect backend operation timeouts, and
// allow-lists the headers/query params useful when debugging retried
// requests in SDK logs. When params.SkipVerify is set, TLS certificate
// verification is disabled (intended for local Azurite testing only).
func newSharedKeyCredentialsClient(params *DriverParameters) (*azureClient, error) {
	cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey)
	if err != nil {
		return nil, fmt.Errorf("shared key credentials: %v", err)
	}

	azBlobOpts := &azblob.ClientOptions{
		ClientOptions: azcore.ClientOptions{
			// Surface retry/timeout events to the driver (see
			// newRetryNotificationPolicy) so ambiguous append failures
			// can be verified instead of blindly retried.
			PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()},
			Logging: policy.LogOptions{
				AllowedHeaders: []string{
					"x-ms-error-code",
					"Retry-After",
					"Retry-After-Ms",
					"If-Match",
					"x-ms-blob-condition-appendpos",
				},
				AllowedQueryParams: []string{"comp"},
			},
		},
	}
	if params.SkipVerify {
		httpTransport := http.DefaultTransport.(*http.Transport).Clone()
		httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		azBlobOpts.Transport = &http.Client{
			Transport: httpTransport,
		}
	}
	client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, azBlobOpts)
	if err != nil {
		return nil, fmt.Errorf("new azure client with shared credentials: %v", err)
	}

	return &azureClient{
		container: params.Container,
		client:    client,
		signer: &sharedKeySigner{
			cred: cred,
		},
	}, nil
}

View File

@@ -6,36 +6,58 @@ import (
"bufio"
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"time"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/base"
"github.com/distribution/distribution/v3/registry/storage/driver/factory"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/appendblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
)
func init() {
factory.Register(driverName, &azureDriverFactory{})
}
var ErrCorruptedData = errors.New("corrupted data found in the uploaded data")
const (
driverName = "azure"
maxChunkSize = 4 * 1024 * 1024
)
type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
params, err := NewParameters(parameters)
if err != nil {
return nil, err
}
return New(ctx, params)
}
var _ storagedriver.StorageDriver = &driver{}
type driver struct {
azClient *azureClient
client *container.Client
rootDirectory string
copyStatusPollMaxRetry int
copyStatusPollDelay time.Duration
azClient *azureClient
client *container.Client
rootDirectory string
maxRetries int
retryDelay time.Duration
}
type baseEmbed struct {
@@ -48,39 +70,25 @@ type Driver struct {
baseEmbed
}
func init() {
factory.Register(driverName, &azureDriverFactory{})
}
type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
params, err := NewParameters(parameters)
if err != nil {
return nil, err
}
return New(ctx, params)
}
// New constructs a new Driver from parameters
func New(ctx context.Context, params *Parameters) (*Driver, error) {
azClient, err := newAzureClient(params)
func New(ctx context.Context, params *DriverParameters) (*Driver, error) {
azClient, err := newClient(params)
if err != nil {
return nil, err
}
copyStatusPollDelay, err := time.ParseDuration(params.CopyStatusPollDelay)
retryDelay, err := time.ParseDuration(params.RetryDelay)
if err != nil {
return nil, err
}
client := azClient.ContainerClient()
d := &driver{
azClient: azClient,
client: client,
rootDirectory: params.RootDirectory,
copyStatusPollMaxRetry: params.CopyStatusPollMaxRetry,
copyStatusPollDelay: copyStatusPollDelay,
azClient: azClient,
client: client,
rootDirectory: params.RootDirectory,
maxRetries: params.MaxRetries,
retryDelay: retryDelay,
}
return &Driver{
baseEmbed: baseEmbed{
@@ -141,9 +149,32 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
}
}
// TODO(milosgajdos): should we set some concurrency options on UploadBuffer
_, err = d.client.NewBlockBlobClient(blobName).UploadBuffer(ctx, contents, nil)
return err
// Always create as AppendBlob
appendBlobRef := d.client.NewAppendBlobClient(blobName)
if _, err := appendBlobRef.Create(ctx, nil); err != nil {
return fmt.Errorf("failed to create append blob: %v", err)
}
// If we have content, append it
if len(contents) > 0 {
// Write in chunks of maxChunkSize otherwise Azure can barf
// when writing a large piece of data in one shot:
// RESPONSE 413: 413 The uploaded entity blob is too large.
for offset := 0; offset < len(contents); offset += maxChunkSize {
end := offset + maxChunkSize
if end > len(contents) {
end = len(contents)
}
chunk := contents[offset:end]
_, err := appendBlobRef.AppendBlock(ctx, streaming.NopCloser(bytes.NewReader(chunk)), nil)
if err != nil {
return fmt.Errorf("failed to append content: %v", err)
}
}
}
return nil
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
@@ -195,6 +226,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (stor
}
blobExists = false
}
eTag := props.ETag
var size int64
if blobExists {
@@ -204,20 +236,27 @@ func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (stor
}
size = *props.ContentLength
} else {
if _, err := blobRef.Delete(ctx, nil); err != nil {
return nil, err
if _, err := blobRef.Delete(ctx, nil); err != nil && !is404(err) {
return nil, fmt.Errorf("deleting existing blob before write: %w", err)
}
res, err := d.client.NewAppendBlobClient(blobName).Create(ctx, nil)
if err != nil {
return nil, fmt.Errorf("creating new append blob: %w", err)
}
eTag = res.ETag
}
} else {
if appendMode {
return nil, storagedriver.PathNotFoundError{Path: path}
return nil, storagedriver.PathNotFoundError{Path: path, DriverName: driverName}
}
if _, err = d.client.NewAppendBlobClient(blobName).Create(ctx, nil); err != nil {
return nil, err
res, err := d.client.NewAppendBlobClient(blobName).Create(ctx, nil)
if err != nil {
return nil, fmt.Errorf("creating new append blob: %w", err)
}
eTag = res.ETag
}
return d.newWriter(ctx, blobName, size), nil
return d.newWriter(ctx, blobName, size, eTag), nil
}
// Stat retrieves the FileInfo for the given path, including the current size
@@ -303,22 +342,21 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL, err := d.signBlobURL(ctx, sourcePath)
if err != nil {
return err
}
srcBlobRef := d.client.NewBlobClient(d.blobName(sourcePath))
sourceBlobURL := srcBlobRef.URL()
destBlobRef := d.client.NewBlockBlobClient(d.blobName(destPath))
resp, err := destBlobRef.StartCopyFromURL(ctx, sourceBlobURL, nil)
if err != nil {
if is404(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
return storagedriver.PathNotFoundError{Path: sourcePath, DriverName: "azure"}
}
return err
}
copyStatus := *resp.CopyStatus
if d.copyStatusPollMaxRetry == -1 && copyStatus == blob.CopyStatusTypePending {
if d.maxRetries == -1 && copyStatus == blob.CopyStatusTypePending {
if _, err := destBlobRef.AbortCopyFromURL(ctx, *resp.CopyID, nil); err != nil {
return err
}
@@ -332,7 +370,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
return err
}
if retryCount >= d.copyStatusPollMaxRetry {
if retryCount >= d.maxRetries {
if _, err := destBlobRef.AbortCopyFromURL(ctx, *props.CopyID, nil); err != nil {
return err
}
@@ -348,7 +386,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
}
if copyStatus == blob.CopyStatusTypePending {
time.Sleep(d.copyStatusPollDelay * time.Duration(retryCount))
time.Sleep(d.retryDelay * time.Duration(retryCount))
}
retryCount++
}
@@ -506,25 +544,30 @@ var _ storagedriver.FileWriter = &writer{}
type writer struct {
driver *driver
path string
size int64
size *atomic.Int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(ctx context.Context, path string, size int64) storagedriver.FileWriter {
return &writer{
func (d *driver) newWriter(ctx context.Context, path string, size int64, eTag *azcore.ETag) storagedriver.FileWriter {
w := &writer{
driver: d,
path: path,
size: size,
// TODO(milosgajdos): I'm not sure about the maxChunkSize
bw: bufio.NewWriterSize(&blockWriter{
ctx: ctx,
client: d.client,
path: path,
}, maxChunkSize),
size: new(atomic.Int64),
}
w.size.Store(size)
bw := bufio.NewWriterSize(&blockWriter{
ctx: ctx,
client: d.client,
path: path,
size: w.size,
maxRetries: int32(d.maxRetries),
eTag: eTag,
}, maxChunkSize)
w.bw = bw
return w
}
func (w *writer) Write(p []byte) (int, error) {
@@ -537,12 +580,11 @@ func (w *writer) Write(p []byte) (int, error) {
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
return w.size.Load()
}
func (w *writer) Close() error {
@@ -578,18 +620,148 @@ func (w *writer) Commit(ctx context.Context) error {
}
type blockWriter struct {
// We construct transient blockWriter objects to encapsulate a write
// and need to keep the context passed in to the original FileWriter.Write
ctx context.Context
client *container.Client
path string
client *container.Client
path string
maxRetries int32
ctx context.Context
size *atomic.Int64
eTag *azcore.ETag
}
func (bw *blockWriter) Write(p []byte) (int, error) {
blobRef := bw.client.NewAppendBlobClient(bw.path)
_, err := blobRef.AppendBlock(bw.ctx, streaming.NopCloser(bytes.NewReader(p)), nil)
if err != nil {
return 0, err
appendBlobRef := bw.client.NewAppendBlobClient(bw.path)
n := 0
offsetRetryCount := int32(0)
for n < len(p) {
appendPos := bw.size.Load()
chunkSize := min(maxChunkSize, len(p)-n)
timeoutFromCtx := false
ctxTimeoutNotify := withTimeoutNotification(bw.ctx, &timeoutFromCtx)
resp, err := appendBlobRef.AppendBlock(
ctxTimeoutNotify,
streaming.NopCloser(bytes.NewReader(p[n:n+chunkSize])),
&appendblob.AppendBlockOptions{
AppendPositionAccessConditions: &appendblob.AppendPositionAccessConditions{
AppendPosition: to.Ptr(appendPos),
},
AccessConditions: &blob.AccessConditions{
ModifiedAccessConditions: &blob.ModifiedAccessConditions{
IfMatch: bw.eTag,
},
},
},
)
if err == nil {
n += chunkSize // number of bytes uploaded in this call to Write()
bw.eTag = resp.ETag
bw.size.Add(int64(chunkSize)) // total size of the blob in the backend
continue
}
appendposFailed := bloberror.HasCode(err, bloberror.AppendPositionConditionNotMet)
etagFailed := bloberror.HasCode(err, bloberror.ConditionNotMet)
if !(appendposFailed || etagFailed) || !timeoutFromCtx {
// Error was not caused by an operation timeout, abort!
return n, fmt.Errorf("appending blob: %w", err)
}
if offsetRetryCount >= bw.maxRetries {
return n, fmt.Errorf("max number of retries (%d) reached while handling backend operation timeout", bw.maxRetries)
}
correctlyUploadedBytes, newEtag, err := bw.chunkUploadVerify(appendPos, p[n:n+chunkSize])
if err != nil {
return n, fmt.Errorf("failed handling operation timeout during blob append: %w", err)
}
bw.eTag = newEtag
if correctlyUploadedBytes == 0 {
offsetRetryCount++
continue
}
offsetRetryCount = 0
// MD5 is correct, data was uploaded. Let's bump the counters and
// continue with the upload
n += int(correctlyUploadedBytes) // number of bytes uploaded in this call to Write()
bw.size.Add(correctlyUploadedBytes) // total size of the blob in the backend
}
return len(p), nil
return n, nil
}
// NOTE: this is more or less copy-pasta from the GitLab fix introduced by @vespian
// https://gitlab.com/gitlab-org/container-registry/-/commit/959132477ef719249270b87ce2a7a05abcd6e1ed?merge_request_iid=2059
//
// chunkUploadVerify determines how much (if any) of chunk actually landed in
// the backend after an ambiguous AppendBlock call that started at offset
// appendPos. It returns the number of bytes verifiably uploaded, the blob's
// current ETag, and an error when verification fails or the backend holds
// data that does not match chunk.
func (bw *blockWriter) chunkUploadVerify(appendPos int64, chunk []byte) (int64, *azcore.ETag, error) {
	// NOTE(prozlach): We need to see if the chunk uploaded or not. As per
	// the documentation, the operation __might__ have succeeded. There are
	// three options:
	// * chunk did not upload, the file size will be the same as bw.size.
	//   In this case we simply need to re-upload the last chunk
	// * chunk or part of it was uploaded - we need to verify the contents
	//   of what has been uploaded with MD5 hash and either:
	//   * MD5 is ok - let's continue uploading data starting from the next
	//     chunk
	//   * MD5 is not OK - we have garbage at the end of the file and
	//     AppendBlock supports only appending, we need to abort and return
	//     permanent error to the caller.
	blobRef := bw.client.NewBlobClient(bw.path)
	props, err := blobRef.GetProperties(bw.ctx, nil)
	if err != nil {
		return 0, nil, fmt.Errorf("determining the end of the blob: %v", err)
	}
	if props.ContentLength == nil {
		return 0, nil, fmt.Errorf("ContentLength in blob properties is missing in reply: %v", err)
	}
	reuploadedBytes := *props.ContentLength - appendPos
	if reuploadedBytes == 0 {
		// NOTE(prozlach): This should never happen really and is here only as
		// a precaution in case something changes in the future. The idea is
		// that if the HTTP call did not succeed and nothing was uploaded, then
		// this code path is not going to be triggered as there will be no
		// AppendPos condition violation during the retry. OTOH, if the write
		// succeeded even partially, then the reuploadedBytes will be greater
		// than zero.
		return 0, props.ETag, nil
	}

	resp, err := blobRef.DownloadStream(
		bw.ctx,
		&blob.DownloadStreamOptions{
			Range:              blob.HTTPRange{Offset: appendPos, Count: reuploadedBytes},
			RangeGetContentMD5: to.Ptr(true), // we always upload <= 4MiB (i.e the maxChunkSize), so we can try to offload the MD5 calculation to azure
		},
	)
	if err != nil {
		return 0, nil, fmt.Errorf("determining the MD5 of the upload blob chunk: %v", err)
	}

	var uploadedMD5 []byte
	// If upstream makes this extra check, then let's be paranoid too.
	if len(resp.ContentMD5) > 0 {
		uploadedMD5 = resp.ContentMD5
	} else {
		// Azure did not return the range's MD5; download the range and
		// compute it locally instead.
		body := resp.NewRetryReader(bw.ctx, &blob.RetryReaderOptions{MaxRetries: bw.maxRetries})
		h := md5.New() // nolint: gosec // ok for content verification
		_, err = io.Copy(h, body)
		// nolint:errcheck
		defer body.Close()
		if err != nil {
			return 0, nil, fmt.Errorf("calculating the MD5 of the uploaded blob chunk: %v", err)
		}
		uploadedMD5 = h.Sum(nil)
	}

	h := md5.New() // nolint: gosec // ok for content verification
	if _, err = io.Copy(h, bytes.NewReader(chunk)); err != nil {
		return 0, nil, fmt.Errorf("calculating the MD5 of the local blob chunk: %v", err)
	}
	localMD5 := h.Sum(nil)

	if !bytes.Equal(uploadedMD5, localMD5) {
		return 0, nil, fmt.Errorf("verifying contents of the uploaded blob chunk: %v", ErrCorruptedData)
	}

	return reuploadedBytes, resp.ETag, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"math/rand"
"os"
"strconv"
"strings"
"testing"
@@ -12,23 +13,29 @@ import (
)
const (
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
envRootDirectory = "AZURE_ROOT_DIRECTORY"
envCredentialsType = "AZURE_STORAGE_CREDENTIALS_TYPE"
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envServiceURL = "AZURE_SERVICE_URL"
envRootDirectory = "AZURE_ROOT_DIRECTORY"
envSkipVerify = "AZURE_SKIP_VERIFY"
)
var azureDriverConstructor func() (storagedriver.StorageDriver, error)
var skipCheck func(tb testing.TB)
var (
azureDriverConstructor func() (storagedriver.StorageDriver, error)
skipCheck func(tb testing.TB)
)
func init() {
var (
accountName string
accountKey string
container string
realm string
rootDirectory string
accountName string
accountKey string
container string
serviceURL string
rootDirectory string
credentialsType string
skipVerify string
)
config := []struct {
@@ -39,8 +46,10 @@ func init() {
{envAccountName, &accountName, false},
{envAccountKey, &accountKey, true},
{envContainer, &container, true},
{envRealm, &realm, true},
{envServiceURL, &serviceURL, false},
{envRootDirectory, &rootDirectory, true},
{envCredentialsType, &credentialsType, true},
{envSkipVerify, &skipVerify, true},
}
missing := []string{}
@@ -51,13 +60,24 @@ func init() {
}
}
skipVerifyBool, err := strconv.ParseBool(skipVerify)
if err != nil {
// NOTE(milosgajdos): if we fail to parse AZURE_SKIP_VERIFY
// we default to verifying TLS certs
skipVerifyBool = false
}
azureDriverConstructor = func() (storagedriver.StorageDriver, error) {
parameters := map[string]interface{}{
"container": container,
"accountname": accountName,
"accountkey": accountKey,
"realm": realm,
"serviceurl": serviceURL,
"rootdirectory": rootDirectory,
"credentials": map[string]any{
"type": credentialsType,
},
"skipverify": skipVerifyBool,
}
params, err := NewParameters(parameters)
if err != nil {
@@ -78,7 +98,11 @@ func init() {
func TestAzureDriverSuite(t *testing.T) {
skipCheck(t)
testsuites.Driver(t, azureDriverConstructor)
skipVerify, err := strconv.ParseBool(os.Getenv(envSkipVerify))
if err != nil {
skipVerify = false
}
testsuites.Driver(t, azureDriverConstructor, skipVerify)
}
func BenchmarkAzureDriverSuite(b *testing.B) {
@@ -161,26 +185,26 @@ func TestParamParsing(t *testing.T) {
}
}
input := []map[string]interface{}{
{"accountname": "acc1", "accountkey": "k1", "container": "c1", "copy_status_poll_max_retry": 1, "copy_status_poll_delay": "10ms"},
{"accountname": "acc1", "accountkey": "k1", "container": "c1", "max_retries": 1, "retry_delay": "10ms"},
{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "default"}},
{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "client_secret", "clientid": "c1", "tenantid": "t1", "secret": "s1"}},
}
expecteds := []Parameters{
expecteds := []DriverParameters{
{
Container: "c1", AccountName: "acc1", AccountKey: "k1",
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
CopyStatusPollMaxRetry: 1, CopyStatusPollDelay: "10ms",
MaxRetries: 1, RetryDelay: "10ms",
},
{
Container: "c1", AccountName: "acc1", Credentials: Credentials{Type: "default"},
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
CopyStatusPollMaxRetry: 5, CopyStatusPollDelay: "100ms",
MaxRetries: 5, RetryDelay: "100ms",
},
{
Container: "c1", AccountName: "acc1",
Credentials: Credentials{Type: "client_secret", ClientID: "c1", TenantID: "t1", Secret: "s1"},
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
CopyStatusPollMaxRetry: 5, CopyStatusPollDelay: "100ms",
MaxRetries: 5, RetryDelay: "100ms",
},
}
for i, expected := range expecteds {

View File

@@ -8,33 +8,42 @@ import (
)
const (
defaultRealm = "core.windows.net"
defaultCopyStatusPollMaxRetry = 5
defaultCopyStatusPollDelay = "100ms"
defaultRealm = "core.windows.net"
defaultMaxRetries = 5
defaultRetryDelay = "100ms"
)
type CredentialsType string
const (
CredentialsTypeClientSecret = "client_secret"
CredentialsTypeSharedKey = "shared_key"
CredentialsTypeDefault = "default_credentials"
)
type Credentials struct {
Type string `mapstructure:"type"`
ClientID string `mapstructure:"clientid"`
TenantID string `mapstructure:"tenantid"`
Secret string `mapstructure:"secret"`
Type CredentialsType `mapstructure:"type"`
ClientID string `mapstructure:"clientid"`
TenantID string `mapstructure:"tenantid"`
Secret string `mapstructure:"secret"`
}
type Parameters struct {
Container string `mapstructure:"container"`
AccountName string `mapstructure:"accountname"`
AccountKey string `mapstructure:"accountkey"`
Credentials Credentials `mapstructure:"credentials"`
ConnectionString string `mapstructure:"connectionstring"`
Realm string `mapstructure:"realm"`
RootDirectory string `mapstructure:"rootdirectory"`
ServiceURL string `mapstructure:"serviceurl"`
CopyStatusPollMaxRetry int `mapstructure:"copy_status_poll_max_retry"`
CopyStatusPollDelay string `mapstructure:"copy_status_poll_delay"`
type DriverParameters struct {
Credentials Credentials `mapstructure:"credentials"`
Container string `mapstructure:"container"`
AccountName string `mapstructure:"accountname"`
AccountKey string `mapstructure:"accountkey"`
ConnectionString string `mapstructure:"connectionstring"`
Realm string `mapstructure:"realm"`
RootDirectory string `mapstructure:"rootdirectory"`
ServiceURL string `mapstructure:"serviceurl"`
MaxRetries int `mapstructure:"max_retries"`
RetryDelay string `mapstructure:"retry_delay"`
SkipVerify bool `mapstructure:"skipverify"`
}
func NewParameters(parameters map[string]interface{}) (*Parameters, error) {
params := Parameters{
func NewParameters(parameters map[string]interface{}) (*DriverParameters, error) {
params := DriverParameters{
Realm: defaultRealm,
}
if err := mapstructure.Decode(parameters, &params); err != nil {
@@ -49,11 +58,11 @@ func NewParameters(parameters map[string]interface{}) (*Parameters, error) {
if params.ServiceURL == "" {
params.ServiceURL = fmt.Sprintf("https://%s.blob.%s", params.AccountName, params.Realm)
}
if params.CopyStatusPollMaxRetry == 0 {
params.CopyStatusPollMaxRetry = defaultCopyStatusPollMaxRetry
if params.MaxRetries == 0 {
params.MaxRetries = defaultMaxRetries
}
if params.CopyStatusPollDelay == "" {
params.CopyStatusPollDelay = defaultCopyStatusPollDelay
if params.RetryDelay == "" {
params.RetryDelay = defaultRetryDelay
}
return &params, nil
}

View File

@@ -0,0 +1,91 @@
package azure
import (
"context"
"net/http"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
)
// Inspired by/credit goes to https://github.com/Azure/azure-storage-azcopy/blob/97ab7b92e766ad48965ac2933495dff1b04fb2a7/ste/xferRetryNotificationPolicy.go
// contextKey is a private type for context value keys so they cannot collide
// with keys defined by other packages.
type contextKey struct {
	name string
}

var (
	// timeoutNotifyContextKey carries a *bool that the policy sets to true
	// when the backend reports an operation timeout.
	timeoutNotifyContextKey = contextKey{"timeoutNotify"}
	// retryNotifyContextKey carries a retryNotificationReceiver whose
	// callback the policy invokes on a retryable (503) response.
	retryNotifyContextKey = contextKey{"retryNotify"}
)

// retryNotificationReceiver should be implemented by code that wishes to be
// notified when a retry happens. Such code must register itself into the
// context, using withRetryNotification, so that the RetryNotificationPolicy
// can invoke the callback when necessary.
type retryNotificationReceiver interface {
	RetryCallback()
}

// withTimeoutNotification returns a context that contains indication of a
// timeout. The retryNotificationPolicy will then set the timeout flag when a
// timeout happens.
func withTimeoutNotification(ctx context.Context, timeout *bool) context.Context {
	return context.WithValue(ctx, timeoutNotifyContextKey, timeout)
}

// withRetryNotification returns a context that contains a retry notifier. The
// retryNotificationPolicy will then invoke the callback when a retry happens.
func withRetryNotification(ctx context.Context, r retryNotificationReceiver) context.Context { // nolint: unused // may become useful at some point
	return context.WithValue(ctx, retryNotifyContextKey, r)
}

// PolicyFunc is a type that implements the Policy interface.
// Use this type when implementing a stateless policy as a first-class function.
type PolicyFunc func(*policy.Request) (*http.Response, error)

// Do implements the Policy interface on PolicyFunc.
func (pf PolicyFunc) Do(req *policy.Request) (*http.Response, error) {
	return pf(req)
}
// newRetryNotificationPolicy builds a pipeline policy that forwards every
// request unchanged and then inspects the response so that interested
// callers, registered via the request context, get notified:
//
//   - HTTP 503 (Service Unavailable): the retryNotificationReceiver stored
//     with withRetryNotification has its RetryCallback invoked.
//   - HTTP 500 (Internal Server Error) carrying the OperationTimedOut error
//     code: the *bool stored with withTimeoutNotification is set to true.
//
// The response and error from the next policy are always returned as-is.
func newRetryNotificationPolicy() policy.Policy {
	// NOTE(prozlach): This is a hacky way to handle all possible cases of
	// emitting error by the Azure backend.
	// In theory we could look just at `x-ms-error-code` HTTP header, but
	// in practice Azure SDK also looks at the body and decodes it as JSON
	// or XML in case when the header is absent.
	// So the idea is to piggy-back on the runtime.NewResponseError that
	// will do the proper decoding for us and just return the ErrorCode
	// field instead.
	extractErrorCode := func(resp *http.Response) string {
		return runtime.NewResponseError(resp).(*azcore.ResponseError).ErrorCode
	}

	return PolicyFunc(func(req *policy.Request) (*http.Response, error) {
		resp, err := req.Next() // issue the request down the pipeline
		if resp == nil {
			return nil, err
		}

		if resp.StatusCode == http.StatusServiceUnavailable {
			// Pull the notification callback out of the context and, if
			// one was registered, invoke it.
			notifier, found := req.Raw().Context().Value(retryNotifyContextKey).(retryNotificationReceiver)
			if found {
				notifier.RetryCallback()
			}
		} else if resp.StatusCode == http.StatusInternalServerError &&
			bloberror.Code(extractErrorCode(resp)) == bloberror.OperationTimedOut {
			// Flag the timeout for the caller that registered a *bool via
			// withTimeoutNotification.
			if timedOut, found := req.Raw().Context().Value(timeoutNotifyContextKey).(*bool); found {
				*timedOut = true
			}
		}

		return resp, err
	})
}

View File

@@ -19,7 +19,7 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor {
}
func TestFilesystemDriverSuite(t *testing.T) {
testsuites.Driver(t, newDriverConstructor(t))
testsuites.Driver(t, newDriverConstructor(t), false)
}
func BenchmarkFilesystemDriverSuite(b *testing.B) {

View File

@@ -93,7 +93,7 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor {
func TestGCSDriverSuite(t *testing.T) {
skipCheck(t)
testsuites.Driver(t, newDriverConstructor(t))
testsuites.Driver(t, newDriverConstructor(t), false)
}
func BenchmarkGCSDriverSuite(b *testing.B) {

View File

@@ -12,7 +12,7 @@ func newDriverConstructor() (storagedriver.StorageDriver, error) {
}
func TestInMemoryDriverSuite(t *testing.T) {
testsuites.Driver(t, newDriverConstructor)
testsuites.Driver(t, newDriverConstructor, false)
}
func BenchmarkInMemoryDriverSuite(b *testing.B) {

View File

@@ -154,7 +154,11 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor {
func TestS3DriverSuite(t *testing.T) {
skipCheck(t)
testsuites.Driver(t, newDriverConstructor(t))
skipVerify, err := strconv.ParseBool(os.Getenv("S3_SKIP_VERIFY"))
if err != nil {
skipVerify = false
}
testsuites.Driver(t, newDriverConstructor(t), skipVerify)
}
func BenchmarkS3DriverSuite(b *testing.B) {

View File

@@ -5,6 +5,7 @@ import (
"context"
crand "crypto/rand"
"crypto/sha256"
"crypto/tls"
"io"
"math/rand"
"net/http"
@@ -43,14 +44,16 @@ type DriverSuite struct {
Constructor DriverConstructor
Teardown DriverTeardown
storagedriver.StorageDriver
ctx context.Context
ctx context.Context
skipVerify bool
}
// Driver runs [DriverSuite] for the given [DriverConstructor].
func Driver(t *testing.T, driverConstructor DriverConstructor) {
func Driver(t *testing.T, driverConstructor DriverConstructor, skipVerify bool) {
suite.Run(t, &DriverSuite{
Constructor: driverConstructor,
ctx: context.Background(),
skipVerify: skipVerify,
})
}
@@ -739,7 +742,19 @@ func (suite *DriverSuite) TestRedirectURL() {
}
suite.Require().NoError(err)
response, err := http.Get(url)
client := http.DefaultClient
if suite.skipVerify {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client = &http.Client{
Transport: httpTransport,
}
}
req, err := http.NewRequest(http.MethodGet, url, nil)
suite.Require().NoError(err)
response, err := client.Do(req)
suite.Require().NoError(err)
defer response.Body.Close()
@@ -753,7 +768,10 @@ func (suite *DriverSuite) TestRedirectURL() {
}
suite.Require().NoError(err)
response, err = http.Head(url)
req, err = http.NewRequest(http.MethodHead, url, nil)
suite.Require().NoError(err)
response, err = client.Do(req)
suite.Require().NoError(err)
defer response.Body.Close()
suite.Require().Equal(200, response.StatusCode)
@@ -1323,7 +1341,7 @@ func (suite *DriverSuite) writeReadCompareStreams(filename string, contents []by
var (
filenameChars = []byte("abcdefghijklmnopqrstuvwxyz0123456789")
separatorChars = []byte("._-")
separatorChars = []byte("-")
)
func randomPath(length int64) string {

View File

@@ -0,0 +1,59 @@
services:
  # Generates a locally-trusted TLS certificate for 127.0.0.1 so Azurite can
  # serve HTTPS; the cert/key pair is shared with the azurite service below.
  cert-init:
    image: alpine/mkcert
    volumes:
      - ./certs:/certs
    entrypoint: /bin/sh
    command: >
      -c "
      mkdir -p /certs &&
      cd /certs &&
      mkcert -install &&
      mkcert 127.0.0.1 &&
      chmod 644 /certs/127.0.0.1.pem /certs/127.0.0.1-key.pem
      "
  # Local Azure Blob Storage emulator used by the Azure storage driver tests.
  azurite:
    image: mcr.microsoft.com/azure-storage/azurite
    ports:
      - "10000:10000"
    volumes:
      - ./certs:/workspace
    # NOTE(review): azurite's --loose is documented as a boolean switch; the
    # trailing 0.0.0.0 argument looks accidental — confirm against azurite docs.
    command: >
      azurite-blob
      --blobHost 0.0.0.0
      --oauth basic
      --loose 0.0.0.0
      --cert /workspace/127.0.0.1.pem
      --key /workspace/127.0.0.1-key.pem
    depends_on:
      cert-init:
        condition: service_completed_successfully
    healthcheck:
      # NOTE(milosgajdos): Azurite does not have a healthcheck endpoint
      # so we are temporarily working around it by using a simple node command.
      # We want to make sure the API is up and running so we are deliberately
      # ignoring that the healthcheck API request fails authorization check
      test: [
        "CMD",
        "node",
        "-e",
        "const http = require('https'); const options = { hostname: '127.0.0.1', port: 10000, path: '/', method: 'GET', rejectUnauthorized: false }; const req = http.request(options, res => process.exit(0)); req.on('error', err => process.exit(1)); req.end();"
      ]
      interval: 5s
      timeout: 5s
      retries: 10
  # One-shot job that creates the test container once Azurite is healthy.
  # The account name/key below are Azurite's well-known development defaults.
  azurite-init:
    image: mcr.microsoft.com/azure-cli
    depends_on:
      azurite:
        condition: service_healthy
    volumes:
      - ./certs:/certs
    environment:
      AZURE_STORAGE_CONNECTION_STRING: "DefaultEndpointsProtocol=https;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=https://azurite:10000/devstoreaccount1;"
      AZURE_CLI_DISABLE_CONNECTION_VERIFICATION: "1"
    entrypoint: >
      /bin/bash -c "
      az storage container create --name containername --connection-string \"$$AZURE_STORAGE_CONNECTION_STRING\" --debug"