chore(registry/storage/driver/s3-aws): refactor writer creation

The logic is identical, but has been separated out and reorganised for
clarity.

Signed-off-by: Thomas Way <thomas@6f.io>
This commit is contained in:
Thomas Way 2024-08-09 19:06:25 +01:00
parent f0bd0f6899
commit 052404a015
No known key found for this signature in database
GPG Key ID: F98E7FF1F9F8C217

View File

@ -665,103 +665,117 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read
// with non-zero committed content.
func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (storagedriver.FileWriter, error) {
key := d.s3Path(path)
if !appendMode {
// TODO (brianbland): cancel other uploads at this path
resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
uploadID, err := d.createMultipartUpload(ctx, key)
if err != nil {
return nil, err
}
return d.newWriter(ctx, key, *resp.UploadId, nil), nil
return d.newWriter(ctx, key, *uploadID, nil), nil
}
listMultipartUploadsInput := &s3.ListMultipartUploadsInput{
Bucket: aws.String(d.Bucket),
Prefix: aws.String(key),
uploadID, err := d.inProgressUpload(ctx, path)
if err != nil {
return nil, err
}
for {
resp, err := d.S3.ListMultipartUploadsWithContext(ctx, listMultipartUploadsInput)
if uploadID == nil {
uploadID, err := d.createMultipartUpload(ctx, key)
if err != nil {
return nil, parseError(path, err)
return nil, err
}
return d.newWriter(ctx, key, *uploadID, nil), nil
}
parts, err := d.listParts(ctx, path, uploadID)
if err != nil {
return nil, err
}
return d.newWriter(ctx, key, *uploadID, parts), nil
}
// createMultiPartUpload creates a new multipart upload for the associated
// bucket and the given key.
func (d *driver) createMultipartUpload(ctx context.Context, key string) (*string, error) {
res, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
if err != nil {
return nil, err
}
return res.UploadId, nil
}
// inProgressUpload finds an in-progress multipart upload for the given path.
// If there is no in-progress upload, a nil upload ID and no error is returned.
func (d *driver) inProgressUpload(ctx context.Context, path string) (uploadID *string, err error) {
var empty bool
if err := d.S3.ListMultipartUploadsPagesWithContext(ctx, &s3.ListMultipartUploadsInput{
Bucket: aws.String(d.Bucket),
Prefix: aws.String(d.s3Path(path)),
}, func(page *s3.ListMultipartUploadsOutput, lastPage bool) bool {
// This condition is only valid for the first page. Subsequent
// pages are guaranteed not to be empty.
if len(page.Uploads) == 0 {
empty = true
return false
}
// resp.Uploads can only be empty on the first call
// if there were no more results to return after the first call, resp.IsTruncated would have been false
// and the loop would be exited without recalling ListMultipartUploads
if len(resp.Uploads) == 0 {
fi, err := d.Stat(ctx, path)
if err != nil {
return nil, parseError(path, err)
}
if fi.Size() == 0 {
resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
if err != nil {
return nil, err
}
return d.newWriter(ctx, key, *resp.UploadId, nil), nil
}
return nil, storagedriver.Error{
DriverName: driverName,
Detail: fmt.Errorf("append to zero-size path %s unsupported", path),
for _, upload := range page.Uploads {
if *upload.Key == d.s3Path(path) {
uploadID = upload.UploadId
return false
}
}
return true
}); err != nil {
return nil, fmt.Errorf("list multipart uploads pages: %w", err)
}
var allParts []*s3.Part
for _, multi := range resp.Uploads {
if key != *multi.Key {
continue
}
partsList, err := d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
UploadId: multi.UploadId,
})
if err != nil {
return nil, parseError(path, err)
}
allParts = append(allParts, partsList.Parts...)
for *partsList.IsTruncated {
partsList, err = d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
UploadId: multi.UploadId,
PartNumberMarker: partsList.NextPartNumberMarker,
})
if err != nil {
return nil, parseError(path, err)
}
allParts = append(allParts, partsList.Parts...)
}
return d.newWriter(ctx, key, *multi.UploadId, allParts), nil
if !empty {
if uploadID != nil {
return uploadID, nil
}
return nil, storagedriver.PathNotFoundError{Path: path}
}
// resp.NextUploadIdMarker must have at least one element or we would have returned not found
listMultipartUploadsInput.UploadIdMarker = resp.NextUploadIdMarker
fi, err := d.Stat(ctx, path)
if err != nil {
return nil, parseError(path, err)
}
// from the s3 api docs, IsTruncated "specifies whether (true) or not (false) all of the results were returned"
// if everything has been returned, break
if resp.IsTruncated == nil || !*resp.IsTruncated {
break
if fi.Size() > 0 {
return nil, storagedriver.Error{
DriverName: driverName,
Detail: fmt.Errorf("append to non zero-size path %s unsupported", path),
}
}
return nil, storagedriver.PathNotFoundError{Path: path}
// An empty file can be overwritten with a new upload.
return nil, nil
}
// listParts lists all parts associated with the uploadID. The parts are not
// guaranteed to be sorted.
func (d *driver) listParts(ctx context.Context, path string, uploadID *string) (parts []*s3.Part, err error) {
if err := d.S3.ListPartsPagesWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
UploadId: uploadID,
}, func(page *s3.ListPartsOutput, lastPage bool) bool {
parts = append(parts, page.Parts...)
return true
}); err != nil {
return nil, fmt.Errorf("list parts pages: %w", err)
}
return parts, nil
}
func (d *driver) statHead(ctx context.Context, path string) (*storagedriver.FileInfoFields, error) {