From 2ffa1171c2ef75394b7241ddc2a74c2130341692 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Fri, 21 Feb 2025 07:16:49 -0800 Subject: [PATCH] Azure driver fix * Make copy poll max retry, a global driver max retry * Get support for etags in Azure * Fix storage driver tests * Fix auth mess and update docs * Refactor Azure client and enable Azure storage tests We use Azurite for integration testing which requires TLS, so we had to figure out how to skip TLS verification when running tests locally: this required updating testsuites Driver and constructor due to TestRedirectURL sending GET and HEAD requests to remote storage which in this case is Azurite. Signed-off-by: Milos Gajdos --- .github/workflows/build.yml | 3 +- Makefile | 21 ++ docs/content/about/configuration.md | 4 +- docs/content/storage-drivers/azure.md | 19 +- .../driver/azure/{azure_auth.go => auth.go} | 130 ++++++-- registry/storage/driver/azure/azure.go | 304 ++++++++++++++---- registry/storage/driver/azure/azure_test.go | 64 ++-- registry/storage/driver/azure/parser.go | 57 ++-- .../driver/azure/retry_notification_policy.go | 91 ++++++ .../storage/driver/filesystem/driver_test.go | 2 +- registry/storage/driver/gcs/gcs_test.go | 2 +- .../storage/driver/inmemory/driver_test.go | 2 +- registry/storage/driver/s3-aws/s3_test.go | 6 +- .../storage/driver/testsuites/testsuites.go | 28 +- tests/docker-compose-azure-blob-store.yaml | 59 ++++ 15 files changed, 636 insertions(+), 156 deletions(-) rename registry/storage/driver/azure/{azure_auth.go => auth.go} (53%) create mode 100644 registry/storage/driver/azure/retry_notification_policy.go create mode 100644 tests/docker-compose-azure-blob-store.yaml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8eb1929da..c93662dde 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,7 +37,8 @@ jobs: - 1.23.6 target: - test-coverage - - test-cloud-storage + - test-cloud-storage # TODO: rename to test-s3-storage + - test-azure-storage steps: - name: Checkout diff --git a/Makefile b/Makefile index 97873bc2c..331885f8b 100644 --- a/Makefile +++ b/Makefile @@ -161,6 +161,27 @@ start-e2e-s3-env: ## starts E2E S3 storage test environment (S3, Redis, registry stop-e2e-s3-env: ## stops E2E S3 storage test environment (S3, Redis, registry) $(COMPOSE) -f tests/docker-compose-e2e-cloud-storage.yml down +.PHONY: test-azure-storage +test-azure-storage: start-azure-storage run-azure-tests stop-azure-storage ## run Azure storage driver tests + +.PHONY: start-azure-storage +start-azure-storage: ## start local Azure storage (Azurite) + $(COMPOSE) -f tests/docker-compose-azure-blob-store.yaml up azurite azurite-init -d + +.PHONY: stop-azure-storage +stop-azure-storage: ## stop local Azure storage (minio) + $(COMPOSE) -f tests/docker-compose-azure-blob-store.yaml down + +.PHONY: run-azure-tests +run-azure-tests: start-azure-storage ## run Azure storage driver integration tests + AZURE_SKIP_VERIFY=true \ + AZURE_STORAGE_CREDENTIALS_TYPE="shared_key" \ + AZURE_STORAGE_ACCOUNT_NAME=devstoreaccount1 \ + AZURE_STORAGE_ACCOUNT_KEY="Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" \ + AZURE_STORAGE_CONTAINER=containername \ + AZURE_SERVICE_URL="https://127.0.0.1:10000/devstoreaccount1" \ + go test ${TESTFLAGS} -count=1 ./registry/storage/driver/azure/... + ##@ Validate lint: ## run all linters diff --git a/docs/content/about/configuration.md b/docs/content/about/configuration.md index 1f83575d4..27f9f5906 100644 --- a/docs/content/about/configuration.md +++ b/docs/content/about/configuration.md @@ -120,8 +120,8 @@ storage: clientid: client_id_string tenantid: tenant_id_string secret: secret_string - copy_status_poll_max_retry: 10 - copy_status_poll_delay: 100ms + max_retries: 10 + retry_delay: 100ms gcs: bucket: bucketname keyfile: /path/to/keyfile diff --git a/docs/content/storage-drivers/azure.md b/docs/content/storage-drivers/azure.md index d2b6c54c1..71ede74e8 100644 --- a/docs/content/storage-drivers/azure.md +++ b/docs/content/storage-drivers/azure.md @@ -13,11 +13,26 @@ An implementation of the `storagedriver.StorageDriver` interface which uses [Mic | `accountname` | yes | Name of the Azure Storage Account. | | `accountkey` | yes | Primary or Secondary Key for the Storage Account. | | `container` | yes | Name of the Azure root storage container in which all registry data is stored. Must comply the storage container name [requirements](https://docs.microsoft.com/rest/api/storageservices/fileservices/naming-and-referencing-containers--blobs--and-metadata). For example, if your url is `https://myaccount.blob.core.windows.net/myblob` use the container value of `myblob`.| +| `credentials` | yes | Azure credentials used to authenticate with Azure blob storage. | +| `rootdirectory` | no | This is a prefix that is applied to all Azure keys to allow you to segment data in your container if necessary. | | `realm` | no | Domain name suffix for the Storage Service API endpoint. For example realm for "Azure in China" would be `core.chinacloudapi.cn` and realm for "Azure Government" would be `core.usgovcloudapi.net`. By default, this is `core.windows.net`. | -| `copy_status_poll_max_retry` | no | Max retry number for polling of copy operation status. Retries use a simple backoff algorithm where each retry number is multiplied by `copy_status_poll_delay`, and this number is used as the delay. Set to -1 to disable retries and abort if the copy does not complete immediately. Defaults to 5. | -| `copy_status_poll_delay` | no | Time to wait between retries for polling of copy operation status. This time is multiplied by N on each retry, where N is the retry number. Defaults to 100ms | +| `max_retries` | no | Max retries for driver operation status. Retries use a simple backoff algorithm where each retry number is multiplied by `retry_delay`, and this number is used as the delay. Set to -1 to disable retries and abort if the copy does not complete immediately. Defaults to 5. | +| `retry_delay` | no | Time to wait between retries for driver operation status. This time is multiplied by N on each retry, where N is the retry number. Defaults to 100ms | +### Credentials + +| Parameter | Required | Description | +|:-----------------------------------|:---------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | yes | Azure credentials used to authenticate with Azure blob storage (`client_secret`, `shared_key`, `default_credentials`). | +| `clientid` | yes | The unique application ID of this application in your directory. | +| `tenantid` | yes | Azure Active Directory’s global unique identifier. | +| `secret` | yes | A secret string that the application uses to prove its identity when requesting a token. | + +* `client_secret`: [used for token euthentication](https://learn.microsoft.com/en-us/azure/developer/go/sdk/authentication/authentication-overview#advantages-of-token-based-authentication) +* `shared_key`: used for shared key credentials authentication (read more [here](https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key)) +* `default_credentials`: [default Azure credential authentication](https://learn.microsoft.com/en-us/azure/developer/go/sdk/authentication/authentication-overview#defaultazurecredential) + ## Related information * To get information about Azure blob storage [the offical docs](https://azure.microsoft.com/en-us/services/storage/). diff --git a/registry/storage/driver/azure/azure_auth.go b/registry/storage/driver/azure/auth.go similarity index 53% rename from registry/storage/driver/azure/azure_auth.go rename to registry/storage/driver/azure/auth.go index 228bcdf29..4d824aa43 100644 --- a/registry/storage/driver/azure/azure_auth.go +++ b/registry/storage/driver/azure/auth.go @@ -2,10 +2,14 @@ package azure import ( "context" + "crypto/tls" + "fmt" + "net/http" "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" @@ -45,49 +49,111 @@ type azureClient struct { signer signer } -func newAzureClient(params *Parameters) (*azureClient, error) { - if params.AccountKey != "" { - cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey) - if err != nil { - return nil, err - } - client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, nil) - if err != nil { - return nil, err - } - signer := &sharedKeySigner{ - cred: cred, - } - return &azureClient{ - container: params.Container, - client: client, - signer: signer, - }, nil +func newClient(params *DriverParameters) (*azureClient, error) { + switch params.Credentials.Type { + case CredentialsTypeClientSecret: + return newTokenClient(params) + case CredentialsTypeSharedKey, CredentialsTypeDefault: + return newSharedKeyCredentialsClient(params) } + return nil, fmt.Errorf("invalid credentials type: %q", params.Credentials.Type) +} - var cred azcore.TokenCredential - var err error - if params.Credentials.Type == "client_secret" { +func newTokenClient(params *DriverParameters) (*azureClient, error) { + var ( + cred azcore.TokenCredential + err error + ) + + switch params.Credentials.Type { + case CredentialsTypeClientSecret: creds := ¶ms.Credentials - if cred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil); err != nil { - return nil, err + cred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil) + if err != nil { + return nil, fmt.Errorf("client secret credentials: %v", err) + } + default: + cred, err = azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, fmt.Errorf("default credentials: %v", err) } - } else if cred, err = azidentity.NewDefaultAzureCredential(nil); err != nil { - return nil, err } - client, err := azblob.NewClient(params.ServiceURL, cred, nil) + azBlobOpts := &azblob.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()}, + Logging: policy.LogOptions{ + AllowedHeaders: []string{ + "x-ms-error-code", + "Retry-After", + "Retry-After-Ms", + "If-Match", + "x-ms-blob-condition-appendpos", + }, + AllowedQueryParams: []string{"comp"}, + }, + }, + } + if params.SkipVerify { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + azBlobOpts.Transport = &http.Client{ + Transport: httpTransport, + } + } + client, err := azblob.NewClient(params.ServiceURL, cred, azBlobOpts) if err != nil { - return nil, err - } - signer := &clientTokenSigner{ - client: client, - cred: cred, + return nil, fmt.Errorf("new azure token client: %v", err) } + return &azureClient{ container: params.Container, client: client, - signer: signer, + signer: &clientTokenSigner{ + client: client, + cred: cred, + }, + }, nil +} + +func newSharedKeyCredentialsClient(params *DriverParameters) (*azureClient, error) { + cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey) + if err != nil { + return nil, fmt.Errorf("shared key credentials: %v", err) + } + azBlobOpts := &azblob.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()}, + Logging: policy.LogOptions{ + AllowedHeaders: []string{ + "x-ms-error-code", + "Retry-After", + "Retry-After-Ms", + "If-Match", + "x-ms-blob-condition-appendpos", + }, + AllowedQueryParams: []string{"comp"}, + }, + }, + } + if params.SkipVerify { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + azBlobOpts.Transport = &http.Client{ + Transport: httpTransport, + } + } + client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, azBlobOpts) + if err != nil { + return nil, fmt.Errorf("new azure client with shared credentials: %v", err) + } + + return &azureClient{ + container: params.Container, + client: client, + signer: &sharedKeySigner{ + cred: cred, + }, }, nil } diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 5b0b5f905..3d0393566 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -6,36 +6,58 @@ import ( "bufio" "bytes" "context" + "crypto/md5" + "errors" "fmt" "io" "net/http" "strings" + "sync/atomic" "time" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/base" "github.com/distribution/distribution/v3/registry/storage/driver/factory" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/appendblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" ) +func init() { + factory.Register(driverName, &azureDriverFactory{}) +} + +var ErrCorruptedData = errors.New("corrupted data found in the uploaded data") + const ( driverName = "azure" maxChunkSize = 4 * 1024 * 1024 ) +type azureDriverFactory struct{} + +func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { + params, err := NewParameters(parameters) + if err != nil { + return nil, err + } + return New(ctx, params) +} + var _ storagedriver.StorageDriver = &driver{} type driver struct { - azClient *azureClient - client *container.Client - rootDirectory string - copyStatusPollMaxRetry int - copyStatusPollDelay time.Duration + azClient *azureClient + client *container.Client + rootDirectory string + maxRetries int + retryDelay time.Duration } type baseEmbed struct { @@ -48,39 +70,25 @@ type Driver struct { baseEmbed } -func init() { - factory.Register(driverName, &azureDriverFactory{}) -} - -type azureDriverFactory struct{} - -func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { - params, err := NewParameters(parameters) - if err != nil { - return nil, err - } - return New(ctx, params) -} - // New constructs a new Driver from parameters -func New(ctx context.Context, params *Parameters) (*Driver, error) { - azClient, err := newAzureClient(params) +func New(ctx context.Context, params *DriverParameters) (*Driver, error) { + azClient, err := newClient(params) if err != nil { return nil, err } - copyStatusPollDelay, err := time.ParseDuration(params.CopyStatusPollDelay) + retryDelay, err := time.ParseDuration(params.RetryDelay) if err != nil { return nil, err } client := azClient.ContainerClient() d := &driver{ - azClient: azClient, - client: client, - rootDirectory: params.RootDirectory, - copyStatusPollMaxRetry: params.CopyStatusPollMaxRetry, - copyStatusPollDelay: copyStatusPollDelay, + azClient: azClient, + client: client, + rootDirectory: params.RootDirectory, + maxRetries: params.MaxRetries, + retryDelay: retryDelay, } return &Driver{ baseEmbed: baseEmbed{ @@ -141,9 +149,32 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e } } - // TODO(milosgajdos): should we set some concurrency options on UploadBuffer - _, err = d.client.NewBlockBlobClient(blobName).UploadBuffer(ctx, contents, nil) - return err + // Always create as AppendBlob + appendBlobRef := d.client.NewAppendBlobClient(blobName) + if _, err := appendBlobRef.Create(ctx, nil); err != nil { + return fmt.Errorf("failed to create append blob: %v", err) + } + + // If we have content, append it + if len(contents) > 0 { + // Write in chunks of maxChunkSize otherwise Azure can barf + // when writing large piece of data in one sot: + // RESPONSE 413: 413 The uploaded entity blob is too large. + for offset := 0; offset < len(contents); offset += maxChunkSize { + end := offset + maxChunkSize + if end > len(contents) { + end = len(contents) + } + + chunk := contents[offset:end] + _, err := appendBlobRef.AppendBlock(ctx, streaming.NopCloser(bytes.NewReader(chunk)), nil) + if err != nil { + return fmt.Errorf("failed to append content: %v", err) + } + } + } + + return nil } // Reader retrieves an io.ReadCloser for the content stored at "path" with a @@ -195,6 +226,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (stor } blobExists = false } + eTag := props.ETag var size int64 if blobExists { @@ -204,20 +236,27 @@ func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (stor } size = *props.ContentLength } else { - if _, err := blobRef.Delete(ctx, nil); err != nil { - return nil, err + if _, err := blobRef.Delete(ctx, nil); err != nil && !is404(err) { + return nil, fmt.Errorf("deleting existing blob before write: %w", err) } + res, err := d.client.NewAppendBlobClient(blobName).Create(ctx, nil) + if err != nil { + return nil, fmt.Errorf("creating new append blob: %w", err) + } + eTag = res.ETag } } else { if appendMode { - return nil, storagedriver.PathNotFoundError{Path: path} + return nil, storagedriver.PathNotFoundError{Path: path, DriverName: driverName} } - if _, err = d.client.NewAppendBlobClient(blobName).Create(ctx, nil); err != nil { - return nil, err + res, err := d.client.NewAppendBlobClient(blobName).Create(ctx, nil) + if err != nil { + return nil, fmt.Errorf("creating new append blob: %w", err) } + eTag = res.ETag } - return d.newWriter(ctx, blobName, size), nil + return d.newWriter(ctx, blobName, size, eTag), nil } // Stat retrieves the FileInfo for the given path, including the current size @@ -303,22 +342,21 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) { // Move moves an object stored at sourcePath to destPath, removing the original // object. func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error { - sourceBlobURL, err := d.signBlobURL(ctx, sourcePath) - if err != nil { - return err - } + srcBlobRef := d.client.NewBlobClient(d.blobName(sourcePath)) + sourceBlobURL := srcBlobRef.URL() + destBlobRef := d.client.NewBlockBlobClient(d.blobName(destPath)) resp, err := destBlobRef.StartCopyFromURL(ctx, sourceBlobURL, nil) if err != nil { if is404(err) { - return storagedriver.PathNotFoundError{Path: sourcePath} + return storagedriver.PathNotFoundError{Path: sourcePath, DriverName: "azure"} } return err } copyStatus := *resp.CopyStatus - if d.copyStatusPollMaxRetry == -1 && copyStatus == blob.CopyStatusTypePending { + if d.maxRetries == -1 && copyStatus == blob.CopyStatusTypePending { if _, err := destBlobRef.AbortCopyFromURL(ctx, *resp.CopyID, nil); err != nil { return err } @@ -332,7 +370,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e return err } - if retryCount >= d.copyStatusPollMaxRetry { + if retryCount >= d.maxRetries { if _, err := destBlobRef.AbortCopyFromURL(ctx, *props.CopyID, nil); err != nil { return err } @@ -348,7 +386,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e } if copyStatus == blob.CopyStatusTypePending { - time.Sleep(d.copyStatusPollDelay * time.Duration(retryCount)) + time.Sleep(d.retryDelay * time.Duration(retryCount)) } retryCount++ } @@ -506,25 +544,30 @@ var _ storagedriver.FileWriter = &writer{} type writer struct { driver *driver path string - size int64 + size *atomic.Int64 bw *bufio.Writer closed bool committed bool cancelled bool } -func (d *driver) newWriter(ctx context.Context, path string, size int64) storagedriver.FileWriter { - return &writer{ +func (d *driver) newWriter(ctx context.Context, path string, size int64, eTag *azcore.ETag) storagedriver.FileWriter { + w := &writer{ driver: d, path: path, - size: size, - // TODO(milosgajdos): I'm not sure about the maxChunkSize - bw: bufio.NewWriterSize(&blockWriter{ - ctx: ctx, - client: d.client, - path: path, - }, maxChunkSize), + size: new(atomic.Int64), } + w.size.Store(size) + bw := bufio.NewWriterSize(&blockWriter{ + ctx: ctx, + client: d.client, + path: path, + size: w.size, + maxRetries: int32(d.maxRetries), + eTag: eTag, + }, maxChunkSize) + w.bw = bw + return w } func (w *writer) Write(p []byte) (int, error) { @@ -537,12 +580,11 @@ func (w *writer) Write(p []byte) (int, error) { } n, err := w.bw.Write(p) - w.size += int64(n) return n, err } func (w *writer) Size() int64 { - return w.size + return w.size.Load() } func (w *writer) Close() error { @@ -578,18 +620,148 @@ func (w *writer) Commit(ctx context.Context) error { } type blockWriter struct { - // We construct transient blockWriter objects to encapsulate a write - // and need to keep the context passed in to the original FileWriter.Write - ctx context.Context - client *container.Client - path string + client *container.Client + path string + maxRetries int32 + ctx context.Context + size *atomic.Int64 + eTag *azcore.ETag } func (bw *blockWriter) Write(p []byte) (int, error) { - blobRef := bw.client.NewAppendBlobClient(bw.path) - _, err := blobRef.AppendBlock(bw.ctx, streaming.NopCloser(bytes.NewReader(p)), nil) - if err != nil { - return 0, err + appendBlobRef := bw.client.NewAppendBlobClient(bw.path) + n := 0 + offsetRetryCount := int32(0) + + for n < len(p) { + appendPos := bw.size.Load() + chunkSize := min(maxChunkSize, len(p)-n) + timeoutFromCtx := false + ctxTimeoutNotify := withTimeoutNotification(bw.ctx, &timeoutFromCtx) + + resp, err := appendBlobRef.AppendBlock( + ctxTimeoutNotify, + streaming.NopCloser(bytes.NewReader(p[n:n+chunkSize])), + &appendblob.AppendBlockOptions{ + AppendPositionAccessConditions: &appendblob.AppendPositionAccessConditions{ + AppendPosition: to.Ptr(appendPos), + }, + AccessConditions: &blob.AccessConditions{ + ModifiedAccessConditions: &blob.ModifiedAccessConditions{ + IfMatch: bw.eTag, + }, + }, + }, + ) + if err == nil { + n += chunkSize // number of bytes uploaded in this call to Write() + bw.eTag = resp.ETag + bw.size.Add(int64(chunkSize)) // total size of the blob in the backend + continue + } + appendposFailed := bloberror.HasCode(err, bloberror.AppendPositionConditionNotMet) + etagFailed := bloberror.HasCode(err, bloberror.ConditionNotMet) + if !(appendposFailed || etagFailed) || !timeoutFromCtx { + // Error was not caused by an operation timeout, abort! + return n, fmt.Errorf("appending blob: %w", err) + } + + if offsetRetryCount >= bw.maxRetries { + return n, fmt.Errorf("max number of retries (%d) reached while handling backend operation timeout", bw.maxRetries) + } + + correctlyUploadedBytes, newEtag, err := bw.chunkUploadVerify(appendPos, p[n:n+chunkSize]) + if err != nil { + return n, fmt.Errorf("failed handling operation timeout during blob append: %w", err) + } + bw.eTag = newEtag + if correctlyUploadedBytes == 0 { + offsetRetryCount++ + continue + } + offsetRetryCount = 0 + + // MD5 is correct, data was uploaded. Let's bump the counters and + // continue with the upload + n += int(correctlyUploadedBytes) // number of bytes uploaded in this call to Write() + bw.size.Add(correctlyUploadedBytes) // total size of the blob in the backend } - return len(p), nil + + return n, nil +} + +// NOTE: this is more or less copy-pasta from the GitLab fix introduced by @vespian +// https://gitlab.com/gitlab-org/container-registry/-/commit/959132477ef719249270b87ce2a7a05abcd6e1ed?merge_request_iid=2059 +func (bw *blockWriter) chunkUploadVerify(appendPos int64, chunk []byte) (int64, *azcore.ETag, error) { + // NOTE(prozlach): We need to see if the chunk uploaded or not. As per + // the documentation, the operation __might__ have succeeded. There are + // three options: + // * chunk did not upload, the file size will be the same as bw.size. + // In this case we simply need to re-upload the last chunk + // * chunk or part of it was uploaded - we need to verify the contents + // of what has been uploaded with MD5 hash and either: + // * MD5 is ok - let's continue uploading data starting from the next + // chunk + // * MD5 is not OK - we have garbadge at the end of the file and + // AppendBlock supports only appending, we need to abort and return + // permament error to the caller. + + blobRef := bw.client.NewBlobClient(bw.path) + props, err := blobRef.GetProperties(bw.ctx, nil) + if err != nil { + return 0, nil, fmt.Errorf("determining the end of the blob: %v", err) + } + if props.ContentLength == nil { + return 0, nil, fmt.Errorf("ContentLength in blob properties is missing in reply: %v", err) + } + reuploadedBytes := *props.ContentLength - appendPos + if reuploadedBytes == 0 { + // NOTE(prozlach): This should never happen really and is here only as + // a precaution in case something changes in the future. The idea is + // that if the HTTP call did not succed and nothing was uploaded, then + // this code path is not going to be triggered as there will be no + // AppendPos condition violation during the retry. OTOH, if the write + // succeeded even partially, then the reuploadedBytes will be greater + // than zero. + return 0, props.ETag, nil + } + + resp, err := blobRef.DownloadStream( + bw.ctx, + &blob.DownloadStreamOptions{ + Range: blob.HTTPRange{Offset: appendPos, Count: reuploadedBytes}, + RangeGetContentMD5: to.Ptr(true), // we always upload <= 4MiB (i.e the maxChunkSize), so we can try to offload the MD5 calculation to azure + }, + ) + if err != nil { + return 0, nil, fmt.Errorf("determining the MD5 of the upload blob chunk: %v", err) + } + var uploadedMD5 []byte + // If upstream makes this extra check, then let's be paranoid too. + if len(resp.ContentMD5) > 0 { + uploadedMD5 = resp.ContentMD5 + } else { + // compute md5 + body := resp.NewRetryReader(bw.ctx, &blob.RetryReaderOptions{MaxRetries: bw.maxRetries}) + h := md5.New() // nolint: gosec // ok for content verification + _, err = io.Copy(h, body) + // nolint:errcheck + defer body.Close() + if err != nil { + return 0, nil, fmt.Errorf("calculating the MD5 of the uploaded blob chunk: %v", err) + } + uploadedMD5 = h.Sum(nil) + } + + h := md5.New() // nolint: gosec // ok for content verification + if _, err = io.Copy(h, bytes.NewReader(chunk)); err != nil { + return 0, nil, fmt.Errorf("calculating the MD5 of the local blob chunk: %v", err) + } + localMD5 := h.Sum(nil) + + if !bytes.Equal(uploadedMD5, localMD5) { + return 0, nil, fmt.Errorf("verifying contents of the uploaded blob chunk: %v", ErrCorruptedData) + } + + return reuploadedBytes, resp.ETag, nil } diff --git a/registry/storage/driver/azure/azure_test.go b/registry/storage/driver/azure/azure_test.go index 9182cd9fe..73a9d07c1 100644 --- a/registry/storage/driver/azure/azure_test.go +++ b/registry/storage/driver/azure/azure_test.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "os" + "strconv" "strings" "testing" @@ -12,23 +13,29 @@ import ( ) const ( - envAccountName = "AZURE_STORAGE_ACCOUNT_NAME" - envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY" - envContainer = "AZURE_STORAGE_CONTAINER" - envRealm = "AZURE_STORAGE_REALM" - envRootDirectory = "AZURE_ROOT_DIRECTORY" + envCredentialsType = "AZURE_STORAGE_CREDENTIALS_TYPE" + envAccountName = "AZURE_STORAGE_ACCOUNT_NAME" + envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY" + envContainer = "AZURE_STORAGE_CONTAINER" + envServiceURL = "AZURE_SERVICE_URL" + envRootDirectory = "AZURE_ROOT_DIRECTORY" + envSkipVerify = "AZURE_SKIP_VERIFY" ) -var azureDriverConstructor func() (storagedriver.StorageDriver, error) -var skipCheck func(tb testing.TB) +var ( + azureDriverConstructor func() (storagedriver.StorageDriver, error) + skipCheck func(tb testing.TB) +) func init() { var ( - accountName string - accountKey string - container string - realm string - rootDirectory string + accountName string + accountKey string + container string + serviceURL string + rootDirectory string + credentialsType string + skipVerify string ) config := []struct { @@ -39,8 +46,10 @@ func init() { {envAccountName, &accountName, false}, {envAccountKey, &accountKey, true}, {envContainer, &container, true}, - {envRealm, &realm, true}, + {envServiceURL, &serviceURL, false}, {envRootDirectory, &rootDirectory, true}, + {envCredentialsType, &credentialsType, true}, + {envSkipVerify, &skipVerify, true}, } missing := []string{} @@ -51,13 +60,24 @@ func init() { } } + skipVerifyBool, err := strconv.ParseBool(skipVerify) + if err != nil { + // NOTE(milosgajdos): if we fail to parse AZURE_SKIP_VERIFY + // we default to verifying TLS certs + skipVerifyBool = false + } + azureDriverConstructor = func() (storagedriver.StorageDriver, error) { parameters := map[string]interface{}{ "container": container, "accountname": accountName, "accountkey": accountKey, - "realm": realm, + "serviceurl": serviceURL, "rootdirectory": rootDirectory, + "credentials": map[string]any{ + "type": credentialsType, + }, + "skipverify": skipVerifyBool, } params, err := NewParameters(parameters) if err != nil { @@ -78,7 +98,11 @@ func init() { func TestAzureDriverSuite(t *testing.T) { skipCheck(t) - testsuites.Driver(t, azureDriverConstructor) + skipVerify, err := strconv.ParseBool(os.Getenv(envSkipVerify)) + if err != nil { + skipVerify = false + } + testsuites.Driver(t, azureDriverConstructor, skipVerify) } func BenchmarkAzureDriverSuite(b *testing.B) { @@ -161,26 +185,26 @@ func TestParamParsing(t *testing.T) { } } input := []map[string]interface{}{ - {"accountname": "acc1", "accountkey": "k1", "container": "c1", "copy_status_poll_max_retry": 1, "copy_status_poll_delay": "10ms"}, + {"accountname": "acc1", "accountkey": "k1", "container": "c1", "max_retries": 1, "retry_delay": "10ms"}, {"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "default"}}, {"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "client_secret", "clientid": "c1", "tenantid": "t1", "secret": "s1"}}, } - expecteds := []Parameters{ + expecteds := []DriverParameters{ { Container: "c1", AccountName: "acc1", AccountKey: "k1", Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net", - CopyStatusPollMaxRetry: 1, CopyStatusPollDelay: "10ms", + MaxRetries: 1, RetryDelay: "10ms", }, { Container: "c1", AccountName: "acc1", Credentials: Credentials{Type: "default"}, Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net", - CopyStatusPollMaxRetry: 5, CopyStatusPollDelay: "100ms", + MaxRetries: 5, RetryDelay: "100ms", }, { Container: "c1", AccountName: "acc1", Credentials: Credentials{Type: "client_secret", ClientID: "c1", TenantID: "t1", Secret: "s1"}, Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net", - CopyStatusPollMaxRetry: 5, CopyStatusPollDelay: "100ms", + MaxRetries: 5, RetryDelay: "100ms", }, } for i, expected := range expecteds { diff --git a/registry/storage/driver/azure/parser.go b/registry/storage/driver/azure/parser.go index c463ac32a..5ecf19060 100644 --- a/registry/storage/driver/azure/parser.go +++ b/registry/storage/driver/azure/parser.go @@ -8,33 +8,42 @@ import ( ) const ( - defaultRealm = "core.windows.net" - defaultCopyStatusPollMaxRetry = 5 - defaultCopyStatusPollDelay = "100ms" + defaultRealm = "core.windows.net" + defaultMaxRetries = 5 + defaultRetryDelay = "100ms" +) + +type CredentialsType string + +const ( + CredentialsTypeClientSecret = "client_secret" + CredentialsTypeSharedKey = "shared_key" + CredentialsTypeDefault = "default_credentials" ) type Credentials struct { - Type string `mapstructure:"type"` - ClientID string `mapstructure:"clientid"` - TenantID string `mapstructure:"tenantid"` - Secret string `mapstructure:"secret"` + Type CredentialsType `mapstructure:"type"` + ClientID string `mapstructure:"clientid"` + TenantID string `mapstructure:"tenantid"` + Secret string `mapstructure:"secret"` } -type Parameters struct { - Container string `mapstructure:"container"` - AccountName string `mapstructure:"accountname"` - AccountKey string `mapstructure:"accountkey"` - Credentials Credentials `mapstructure:"credentials"` - ConnectionString string `mapstructure:"connectionstring"` - Realm string `mapstructure:"realm"` - RootDirectory string `mapstructure:"rootdirectory"` - ServiceURL string `mapstructure:"serviceurl"` - CopyStatusPollMaxRetry int `mapstructure:"copy_status_poll_max_retry"` - CopyStatusPollDelay string `mapstructure:"copy_status_poll_delay"` +type DriverParameters struct { + Credentials Credentials `mapstructure:"credentials"` + Container string `mapstructure:"container"` + AccountName string `mapstructure:"accountname"` + AccountKey string `mapstructure:"accountkey"` + ConnectionString string `mapstructure:"connectionstring"` + Realm string `mapstructure:"realm"` + RootDirectory string `mapstructure:"rootdirectory"` + ServiceURL string `mapstructure:"serviceurl"` + MaxRetries int `mapstructure:"max_retries"` + RetryDelay string `mapstructure:"retry_delay"` + SkipVerify bool `mapstructure:"skipverify"` } -func NewParameters(parameters map[string]interface{}) (*Parameters, error) { - params := Parameters{ +func NewParameters(parameters map[string]interface{}) (*DriverParameters, error) { + params := DriverParameters{ Realm: defaultRealm, } if err := mapstructure.Decode(parameters, ¶ms); err != nil { @@ -49,11 +58,11 @@ func NewParameters(parameters map[string]interface{}) (*Parameters, error) { if params.ServiceURL == "" { params.ServiceURL = fmt.Sprintf("https://%s.blob.%s", params.AccountName, params.Realm) } - if params.CopyStatusPollMaxRetry == 0 { - params.CopyStatusPollMaxRetry = defaultCopyStatusPollMaxRetry + if params.MaxRetries == 0 { + params.MaxRetries = defaultMaxRetries } - if params.CopyStatusPollDelay == "" { - params.CopyStatusPollDelay = defaultCopyStatusPollDelay + if params.RetryDelay == "" { + params.RetryDelay = defaultRetryDelay } return ¶ms, nil } diff --git a/registry/storage/driver/azure/retry_notification_policy.go b/registry/storage/driver/azure/retry_notification_policy.go new file mode 100644 index 000000000..4e97970a1 --- /dev/null +++ b/registry/storage/driver/azure/retry_notification_policy.go @@ -0,0 +1,91 @@ +package azure + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" +) + +// Inspired by/credit goes to https://github.com/Azure/azure-storage-azcopy/blob/97ab7b92e766ad48965ac2933495dff1b04fb2a7/ste/xferRetryNotificationPolicy.go +type contextKey struct { + name string +} + +var ( + timeoutNotifyContextKey = contextKey{"timeoutNotify"} + retryNotifyContextKey = contextKey{"retryNotify"} +) + +// retryNotificationReceiver should be implemented by code that wishes to be +// notified when a retry happens. Such code must register itself into the +// context, using withRetryNotification, so that the RetryNotificationPolicy +// can invoke the callback when necessary. +type retryNotificationReceiver interface { + RetryCallback() +} + +// withTimeoutNotification returns a context that contains indication of a +// timeout. The retryNotificationPolicy will then set the timeout flag when a +// timeout happens +func withTimeoutNotification(ctx context.Context, timeout *bool) context.Context { + return context.WithValue(ctx, timeoutNotifyContextKey, timeout) +} + +// withRetryNotifier returns a context that contains a retry notifier. The +// retryNotificationPolicy will then invoke the callback when a retry happens +func withRetryNotification(ctx context.Context, r retryNotificationReceiver) context.Context { // nolint: unused // may become useful at some point + return context.WithValue(ctx, retryNotifyContextKey, r) +} + +// PolicyFunc is a type that implements the Policy interface. +// Use this type when implementing a stateless policy as a first-class function. +type PolicyFunc func(*policy.Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf PolicyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) +} + +func newRetryNotificationPolicy() policy.Policy { + getErrorCode := func(resp *http.Response) string { + // NOTE(prozlach): This is a hacky way to handle all possible cases of + // emitting error by the Azure backend. + // In theory we could look just at `x-ms-error-code` HTTP header, but + // in practice Azure SDK also looks at the body and decodes it as JSON + // or XML in case when the header is absent. + // So the idea is to piggy-back on the runtime.NewResponseError that + // will do the proper decoding for us and just return the ErrorCode + // field instead. + return runtime.NewResponseError(resp).(*azcore.ResponseError).ErrorCode + } + + return PolicyFunc(func(req *policy.Request) (*http.Response, error) { + response, err := req.Next() // Make the request + + if response == nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusServiceUnavailable: + // Grab the notification callback out of the context and, if its there, call it + if notifier, ok := req.Raw().Context().Value(retryNotifyContextKey).(retryNotificationReceiver); ok { + notifier.RetryCallback() + } + case http.StatusInternalServerError: + errorCodeHeader := getErrorCode(response) + if bloberror.Code(errorCodeHeader) != bloberror.OperationTimedOut { + break + } + + if timeout, ok := req.Raw().Context().Value(timeoutNotifyContextKey).(*bool); ok { + *timeout = true + } + } + return response, err + }) +} diff --git a/registry/storage/driver/filesystem/driver_test.go b/registry/storage/driver/filesystem/driver_test.go index e236aa47f..71e2273cc 100644 --- a/registry/storage/driver/filesystem/driver_test.go +++ b/registry/storage/driver/filesystem/driver_test.go @@ -19,7 +19,7 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor { } func TestFilesystemDriverSuite(t *testing.T) { - testsuites.Driver(t, newDriverConstructor(t)) + testsuites.Driver(t, newDriverConstructor(t), false) } func BenchmarkFilesystemDriverSuite(b *testing.B) { diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index 9978d1da1..926117807 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -93,7 +93,7 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor { func TestGCSDriverSuite(t *testing.T) { skipCheck(t) - testsuites.Driver(t, newDriverConstructor(t)) + testsuites.Driver(t, newDriverConstructor(t), false) } func BenchmarkGCSDriverSuite(b *testing.B) { diff --git a/registry/storage/driver/inmemory/driver_test.go b/registry/storage/driver/inmemory/driver_test.go index b1ede2c3b..54b6fe0f3 100644 --- a/registry/storage/driver/inmemory/driver_test.go +++ b/registry/storage/driver/inmemory/driver_test.go @@ -12,7 +12,7 @@ func newDriverConstructor() (storagedriver.StorageDriver, error) { } func TestInMemoryDriverSuite(t *testing.T) { - testsuites.Driver(t, newDriverConstructor) + testsuites.Driver(t, newDriverConstructor, false) } func BenchmarkInMemoryDriverSuite(b *testing.B) { diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index ab7094708..ff61ef1ad 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -154,7 +154,11 @@ func newDriverConstructor(tb testing.TB) testsuites.DriverConstructor { func TestS3DriverSuite(t *testing.T) { skipCheck(t) - testsuites.Driver(t, newDriverConstructor(t)) + skipVerify, err := strconv.ParseBool(os.Getenv("S3_SKIP_VERIFY")) + if err != nil { + skipVerify = false + } + testsuites.Driver(t, newDriverConstructor(t), skipVerify) } func BenchmarkS3DriverSuite(b *testing.B) { diff --git a/registry/storage/driver/testsuites/testsuites.go b/registry/storage/driver/testsuites/testsuites.go index b1221a64b..c61ecc6cd 100644 --- a/registry/storage/driver/testsuites/testsuites.go +++ b/registry/storage/driver/testsuites/testsuites.go @@ -5,6 +5,7 @@ import ( "context" crand "crypto/rand" "crypto/sha256" + "crypto/tls" "io" "math/rand" "net/http" @@ -43,14 +44,16 @@ type DriverSuite struct { Constructor DriverConstructor Teardown DriverTeardown storagedriver.StorageDriver - ctx context.Context + ctx context.Context + skipVerify bool } // Driver runs [DriverSuite] for the given [DriverConstructor]. -func Driver(t *testing.T, driverConstructor DriverConstructor) { +func Driver(t *testing.T, driverConstructor DriverConstructor, skipVerify bool) { suite.Run(t, &DriverSuite{ Constructor: driverConstructor, ctx: context.Background(), + skipVerify: skipVerify, }) } @@ -739,7 +742,19 @@ func (suite *DriverSuite) TestRedirectURL() { } suite.Require().NoError(err) - response, err := http.Get(url) + client := http.DefaultClient + if suite.skipVerify { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + client = &http.Client{ + Transport: httpTransport, + } + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + suite.Require().NoError(err) + + response, err := client.Do(req) suite.Require().NoError(err) defer response.Body.Close() @@ -753,7 +768,10 @@ func (suite *DriverSuite) TestRedirectURL() { } suite.Require().NoError(err) - response, err = http.Head(url) + req, err = http.NewRequest(http.MethodHead, url, nil) + suite.Require().NoError(err) + + response, err = client.Do(req) suite.Require().NoError(err) defer response.Body.Close() suite.Require().Equal(200, response.StatusCode) @@ -1323,7 +1341,7 @@ func (suite *DriverSuite) writeReadCompareStreams(filename string, contents []by var ( filenameChars = []byte("abcdefghijklmnopqrstuvwxyz0123456789") - separatorChars = []byte("._-") + separatorChars = []byte("-") ) func randomPath(length int64) string { diff --git a/tests/docker-compose-azure-blob-store.yaml b/tests/docker-compose-azure-blob-store.yaml new file mode 100644 index 000000000..c1af9e0dd --- /dev/null +++ b/tests/docker-compose-azure-blob-store.yaml @@ -0,0 +1,59 @@ +services: + cert-init: + image: alpine/mkcert + volumes: + - ./certs:/certs + entrypoint: /bin/sh + command: > + -c " + mkdir -p /certs && + cd /certs && + mkcert -install && + mkcert 127.0.0.1 && + chmod 644 /certs/127.0.0.1.pem /certs/127.0.0.1-key.pem + " + + azurite: + image: mcr.microsoft.com/azure-storage/azurite + ports: + - "10000:10000" + volumes: + - ./certs:/workspace + command: > + azurite-blob + --blobHost 0.0.0.0 + --oauth basic + --loose 0.0.0.0 + --cert /workspace/127.0.0.1.pem + --key /workspace/127.0.0.1-key.pem + depends_on: + cert-init: + condition: service_completed_successfully + healthcheck: + # NOTE(milosgajdos): Azurite does not have a healtcheck endpoint + # so we are temporarilty working around it by using a simple node command. + # We want to make sure the API is up and running so we are deliberately + # ignoring that the healthcheck API request fails authorization check + test: [ + "CMD", + "node", + "-e", + "const http = require('https'); const options = { hostname: '127.0.0.1', port: 10000, path: '/', method: 'GET', rejectUnauthorized: false }; const req = http.request(options, res => process.exit(0)); req.on('error', err => process.exit(1)); req.end();" + ] + interval: 5s + timeout: 5s + retries: 10 + + azurite-init: + image: mcr.microsoft.com/azure-cli + depends_on: + azurite: + condition: service_healthy + volumes: + - ./certs:/certs + environment: + AZURE_STORAGE_CONNECTION_STRING: "DefaultEndpointsProtocol=https;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=https://azurite:10000/devstoreaccount1;" + AZURE_CLI_DISABLE_CONNECTION_VERIFICATION: "1" + entrypoint: > + /bin/bash -c " + az storage container create --name containername --connection-string \"$$AZURE_STORAGE_CONNECTION_STRING\" --debug"