This commit is contained in:
Thomas 2025-06-10 22:56:01 -04:00 committed by GitHub
commit 1a9039b7e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -653,103 +653,117 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read
// with non-zero committed content. // with non-zero committed content.
func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (storagedriver.FileWriter, error) { func (d *driver) Writer(ctx context.Context, path string, appendMode bool) (storagedriver.FileWriter, error) {
key := d.s3Path(path) key := d.s3Path(path)
if !appendMode { if !appendMode {
// TODO (brianbland): cancel other uploads at this path // TODO (brianbland): cancel other uploads at this path
resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{ uploadID, err := d.createMultipartUpload(ctx, key)
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.newWriter(ctx, key, *resp.UploadId, nil), nil return d.newWriter(ctx, key, *uploadID, nil), nil
} }
listMultipartUploadsInput := &s3.ListMultipartUploadsInput{ uploadID, err := d.inProgressUpload(ctx, path)
Bucket: aws.String(d.Bucket), if err != nil {
Prefix: aws.String(key), return nil, err
} }
for {
resp, err := d.S3.ListMultipartUploadsWithContext(ctx, listMultipartUploadsInput) if uploadID == nil {
uploadID, err := d.createMultipartUpload(ctx, key)
if err != nil { if err != nil {
return nil, parseError(path, err) return nil, err
}
return d.newWriter(ctx, key, *uploadID, nil), nil
}
parts, err := d.listParts(ctx, path, uploadID)
if err != nil {
return nil, err
}
return d.newWriter(ctx, key, *uploadID, parts), nil
}
// createMultiPartUpload creates a new multipart upload for the associated
// bucket and the given key.
func (d *driver) createMultipartUpload(ctx context.Context, key string) (*string, error) {
res, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
if err != nil {
return nil, err
}
return res.UploadId, nil
}
// inProgressUpload finds an in-progress multipart upload for the given path.
// If there is no in-progress upload, a nil upload ID and no error is returned.
func (d *driver) inProgressUpload(ctx context.Context, path string) (uploadID *string, err error) {
var empty bool
if err := d.S3.ListMultipartUploadsPagesWithContext(ctx, &s3.ListMultipartUploadsInput{
Bucket: aws.String(d.Bucket),
Prefix: aws.String(d.s3Path(path)),
}, func(page *s3.ListMultipartUploadsOutput, lastPage bool) bool {
// This condition is only valid for the first page. Subsequent
// pages are guaranteed not to be empty.
if len(page.Uploads) == 0 {
empty = true
return false
} }
// resp.Uploads can only be empty on the first call for _, upload := range page.Uploads {
// if there were no more results to return after the first call, resp.IsTruncated would have been false if *upload.Key == d.s3Path(path) {
// and the loop would be exited without recalling ListMultipartUploads uploadID = upload.UploadId
if len(resp.Uploads) == 0 { return false
fi, err := d.Stat(ctx, path)
if err != nil {
return nil, parseError(path, err)
}
if fi.Size() == 0 {
resp, err := d.S3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
ContentType: d.getContentType(),
ACL: d.getACL(),
ServerSideEncryption: d.getEncryptionMode(),
SSEKMSKeyId: d.getSSEKMSKeyID(),
StorageClass: d.getStorageClass(),
})
if err != nil {
return nil, err
}
return d.newWriter(ctx, key, *resp.UploadId, nil), nil
}
return nil, storagedriver.Error{
DriverName: driverName,
Detail: fmt.Errorf("append to zero-size path %s unsupported", path),
} }
} }
return true
}); err != nil {
return nil, fmt.Errorf("list multipart uploads pages: %w", err)
}
var allParts []*s3.Part if !empty {
for _, multi := range resp.Uploads { if uploadID != nil {
if key != *multi.Key { return uploadID, nil
continue
}
partsList, err := d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
UploadId: multi.UploadId,
})
if err != nil {
return nil, parseError(path, err)
}
allParts = append(allParts, partsList.Parts...)
for *partsList.IsTruncated {
partsList, err = d.S3.ListPartsWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
UploadId: multi.UploadId,
PartNumberMarker: partsList.NextPartNumberMarker,
})
if err != nil {
return nil, parseError(path, err)
}
allParts = append(allParts, partsList.Parts...)
}
return d.newWriter(ctx, key, *multi.UploadId, allParts), nil
} }
return nil, storagedriver.PathNotFoundError{Path: path}
}
// resp.NextUploadIdMarker must have at least one element or we would have returned not found fi, err := d.Stat(ctx, path)
listMultipartUploadsInput.UploadIdMarker = resp.NextUploadIdMarker if err != nil {
return nil, parseError(path, err)
}
// from the s3 api docs, IsTruncated "specifies whether (true) or not (false) all of the results were returned" if fi.Size() > 0 {
// if everything has been returned, break return nil, storagedriver.Error{
if resp.IsTruncated == nil || !*resp.IsTruncated { DriverName: driverName,
break Detail: fmt.Errorf("append to non zero-size path %s unsupported", path),
} }
} }
return nil, storagedriver.PathNotFoundError{Path: path}
// An empty file can be overwritten with a new upload.
return nil, nil
}
// listParts lists all parts associated with the uploadID. The parts are not
// guaranteed to be sorted.
func (d *driver) listParts(ctx context.Context, path string, uploadID *string) (parts []*s3.Part, err error) {
if err := d.S3.ListPartsPagesWithContext(ctx, &s3.ListPartsInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
UploadId: uploadID,
}, func(page *s3.ListPartsOutput, lastPage bool) bool {
parts = append(parts, page.Parts...)
return true
}); err != nil {
return nil, fmt.Errorf("list parts pages: %w", err)
}
return parts, nil
} }
func (d *driver) statHead(ctx context.Context, path string) (*storagedriver.FileInfoFields, error) { func (d *driver) statHead(ctx context.Context, path string) (*storagedriver.FileInfoFields, error) {