From 052404a015c36e563d06546aadce4112d81b89b2 Mon Sep 17 00:00:00 2001 From: Thomas Way Date: Fri, 9 Aug 2024 19:06:25 +0100 Subject: [PATCH] chore(registry/storage/driver/s3-aws): refactor writer creation The logic is identical, but has been separated out and reorganised for clarity. Signed-off-by: Thomas Way --- registry/storage/driver/s3-aws/s3.go | 170 +++++++++++++++------------ 1 file changed, 92 insertions(+), 78 deletions(-) diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index 48db5f4a1..435f11a81 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -665,103 +665,117 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read // with non-zero committed content. func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (storagedriver.FileWriter, error) { key := d.s3Path(path) + if !appendMode { // TODO (brianbland): cancel other uploads at this path - resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{ - Bucket: aws.String(d.Bucket), - Key: aws.String(key), - ContentType: d.getContentType(), - ACL: d.getACL(), - ServerSideEncryption: d.getEncryptionMode(), - SSEKMSKeyId: d.getSSEKMSKeyID(), - StorageClass: d.getStorageClass(), - }) + uploadID, err := d.createMultipartUpload(ctx, key) if err != nil { return nil, err } - return d.newWriter(ctx, key, *resp.UploadId, nil), nil + return d.newWriter(ctx, key, *uploadID, nil), nil } - listMultipartUploadsInput := &s3.ListMultipartUploadsInput{ - Bucket: aws.String(d.Bucket), - Prefix: aws.String(key), + uploadID, err := d.inProgressUpload(ctx, path) + if err != nil { + return nil, err } - for { - resp, err := d.S3.ListMultipartUploadsWithContext(ctx, listMultipartUploadsInput) + + if uploadID == nil { + uploadID, err := d.createMultipartUpload(ctx, key) if err != nil { - return nil, parseError(path, err) + return nil, err + } + return d.newWriter(ctx, key, *uploadID, nil), nil + } + + parts, err := d.listParts(ctx, path, uploadID) + if err != nil { + return nil, err + } + return d.newWriter(ctx, key, *uploadID, parts), nil +} + +// createMultiPartUpload creates a new multipart upload for the associated +// bucket and the given key. +func (d *driver) createMultipartUpload(ctx context.Context, key string) (*string, error) { + res, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{ + Bucket: aws.String(d.Bucket), + Key: aws.String(key), + ContentType: d.getContentType(), + ACL: d.getACL(), + ServerSideEncryption: d.getEncryptionMode(), + SSEKMSKeyId: d.getSSEKMSKeyID(), + StorageClass: d.getStorageClass(), + }) + if err != nil { + return nil, err + } + return res.UploadId, nil +} + +// inProgressUpload finds an in-progress multipart upload for the given path. +// If there is no in-progress upload, a nil upload ID and no error is returned. +func (d *driver) inProgressUpload(ctx context.Context, path string) (uploadID *string, err error) { + var empty bool + if err := d.S3.ListMultipartUploadsPagesWithContext(ctx, &s3.ListMultipartUploadsInput{ + Bucket: aws.String(d.Bucket), + Prefix: aws.String(d.s3Path(path)), + }, func(page *s3.ListMultipartUploadsOutput, lastPage bool) bool { + // This condition is only valid for the first page. Subsequent + // pages are guaranteed not to be empty. + if len(page.Uploads) == 0 { + empty = true + return false } - // resp.Uploads can only be empty on the first call - // if there were no more results to return after the first call, resp.IsTruncated would have been false - // and the loop would be exited without recalling ListMultipartUploads - if len(resp.Uploads) == 0 { - fi, err := d.Stat(ctx, path) - if err != nil { - return nil, parseError(path, err) - } - - if fi.Size() == 0 { - resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{ - Bucket: aws.String(d.Bucket), - Key: aws.String(key), - ContentType: d.getContentType(), - ACL: d.getACL(), - ServerSideEncryption: d.getEncryptionMode(), - SSEKMSKeyId: d.getSSEKMSKeyID(), - StorageClass: d.getStorageClass(), - }) - if err != nil { - return nil, err - } - return d.newWriter(ctx, key, *resp.UploadId, nil), nil - } - return nil, storagedriver.Error{ - DriverName: driverName, - Detail: fmt.Errorf("append to zero-size path %s unsupported", path), + for _, upload := range page.Uploads { + if *upload.Key == d.s3Path(path) { + uploadID = upload.UploadId + return false } } + return true + }); err != nil { + return nil, fmt.Errorf("list multipart uploads pages: %w", err) + } - var allParts []*s3.Part - for _, multi := range resp.Uploads { - if key != *multi.Key { - continue - } - - partsList, err := d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{ - Bucket: aws.String(d.Bucket), - Key: aws.String(key), - UploadId: multi.UploadId, - }) - if err != nil { - return nil, parseError(path, err) - } - allParts = append(allParts, partsList.Parts...) - for *partsList.IsTruncated { - partsList, err = d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{ - Bucket: aws.String(d.Bucket), - Key: aws.String(key), - UploadId: multi.UploadId, - PartNumberMarker: partsList.NextPartNumberMarker, - }) - if err != nil { - return nil, parseError(path, err) - } - allParts = append(allParts, partsList.Parts...) - } - return d.newWriter(ctx, key, *multi.UploadId, allParts), nil + if !empty { + if uploadID != nil { + return uploadID, nil } + return nil, storagedriver.PathNotFoundError{Path: path} + } - // resp.NextUploadIdMarker must have at least one element or we would have returned not found - listMultipartUploadsInput.UploadIdMarker = resp.NextUploadIdMarker + fi, err := d.Stat(ctx, path) + if err != nil { + return nil, parseError(path, err) + } - // from the s3 api docs, IsTruncated "specifies whether (true) or not (false) all of the results were returned" - // if everything has been returned, break - if resp.IsTruncated == nil || !*resp.IsTruncated { - break + if fi.Size() > 0 { + return nil, storagedriver.Error{ + DriverName: driverName, + Detail: fmt.Errorf("append to non zero-size path %s unsupported", path), } } - return nil, storagedriver.PathNotFoundError{Path: path} + + // An empty file can be overwritten with a new upload. + return nil, nil +} + +// listParts lists all parts associated with the uploadID. The parts are not +// guaranteed to be sorted. +func (d *driver) listParts(ctx context.Context, path string, uploadID *string) (parts []*s3.Part, err error) { + if err := d.S3.ListPartsPagesWithContext(ctx, &s3.ListPartsInput{ + Bucket: aws.String(d.Bucket), + Key: aws.String(d.s3Path(path)), + UploadId: uploadID, + }, func(page *s3.ListPartsOutput, lastPage bool) bool { + parts = append(parts, page.Parts...) + return true + }); err != nil { + return nil, fmt.Errorf("list parts pages: %w", err) + } + return parts, nil } func (d *driver) statHead(ctx context.Context, path string) (*storagedriver.FileInfoFields, error) {