Merge pull request #3839 from kirat-singh/feature.azure-sdk-update

Update Azure SDK and support additional authentication schemes
This commit is contained in:
Milos Gajdos
2023-04-25 19:35:34 +01:00
committed by GitHub
365 changed files with 44060 additions and 21016 deletions

View File

@@ -94,7 +94,7 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
// the writer and canceling the operation.
func (bw *blobWriter) Cancel(ctx context.Context) error {
dcontext.GetLogger(ctx).Debug("(*blobWriter).Cancel")
if err := bw.fileWriter.Cancel(); err != nil {
if err := bw.fileWriter.Cancel(ctx); err != nil {
return err
}

View File

@@ -8,7 +8,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
@@ -16,22 +15,22 @@ import (
"github.com/distribution/distribution/v3/registry/storage/driver/base"
"github.com/distribution/distribution/v3/registry/storage/driver/factory"
azure "github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
)
const driverName = "azure"
const (
paramAccountName = "accountname"
paramAccountKey = "accountkey"
paramContainer = "container"
paramRealm = "realm"
maxChunkSize = 4 * 1024 * 1024
maxChunkSize = 4 * 1024 * 1024
)
type driver struct {
client azure.BlobStorageClient
container string
azClient *azureClient
client *container.Client
rootDirectory string
}
type baseEmbed struct{ base.Base }
@@ -47,53 +46,24 @@ func init() {
type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
// FromParameters constructs a new Driver with a given parameters map.
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
accountName, ok := parameters[paramAccountName]
if !ok || fmt.Sprint(accountName) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramAccountName)
}
accountKey, ok := parameters[paramAccountKey]
if !ok || fmt.Sprint(accountKey) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramAccountKey)
}
container, ok := parameters[paramContainer]
if !ok || fmt.Sprint(container) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramContainer)
}
realm, ok := parameters[paramRealm]
if !ok || fmt.Sprint(realm) == "" {
realm = azure.DefaultBaseURL
}
return New(fmt.Sprint(accountName), fmt.Sprint(accountKey), fmt.Sprint(container), fmt.Sprint(realm))
}
// New constructs a new Driver with the given Azure Storage Account credentials
func New(accountName, accountKey, container, realm string) (*Driver, error) {
api, err := azure.NewClient(accountName, accountKey, realm, azure.DefaultAPIVersion, true)
params, err := NewParameters(parameters)
if err != nil {
return nil, err
}
return New(params)
}
blobClient := api.GetBlobService()
// Create registry container
containerRef := blobClient.GetContainerReference(container)
if _, err = containerRef.CreateIfNotExists(nil); err != nil {
// New constructs a new Driver from parameters
func New(params *Parameters) (*Driver, error) {
azClient, err := newAzureClient(params)
if err != nil {
return nil, err
}
client := azClient.ContainerClient()
d := &driver{
client: blobClient,
container: container,
}
azClient: azClient,
client: client,
rootDirectory: params.RootDirectory}
return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil
}
@@ -104,17 +74,16 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blob, err := blobRef.Get(nil)
downloadResponse, err := d.client.NewBlobClient(d.blobName(path)).DownloadStream(ctx, nil)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
defer blob.Close()
return io.ReadAll(blob)
body := downloadResponse.Body
defer body.Close()
return io.ReadAll(body)
}
// PutContent stores the []byte content at a location designated by "path".
@@ -137,75 +106,80 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
// losing the existing data while migrating it to BlockBlob type. However,
// expectation is the clients pushing will be retrying when they get an error
// response.
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
err := blobRef.GetProperties(nil)
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
props, err := blobRef.GetProperties(ctx, nil)
if err != nil && !is404(err) {
return fmt.Errorf("failed to get blob properties: %v", err)
}
if err == nil && blobRef.Properties.BlobType != azure.BlobTypeBlock {
if err := blobRef.Delete(nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%s): %v", blobRef.Properties.BlobType, err)
if err == nil && props.BlobType != nil && *props.BlobType != blob.BlobTypeBlockBlob {
if _, err := blobRef.Delete(ctx, nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%v): %v", *props.BlobType, err)
}
}
r := bytes.NewReader(contents)
// reset properties to empty before doing overwrite
blobRef.Properties = azure.BlobProperties{}
return blobRef.CreateBlockBlobFromReader(r, nil)
_, err = d.client.NewBlockBlobClient(blobName).UploadBuffer(ctx, contents, nil)
return err
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
if ok, err := blobRef.Exists(); err != nil {
return nil, err
} else if !ok {
return nil, storagedriver.PathNotFoundError{Path: path}
blobRef := d.client.NewBlobClient(d.blobName(path))
options := blob.DownloadStreamOptions{
Range: blob.HTTPRange{
Offset: offset,
},
}
err := blobRef.GetProperties(nil)
props, err := blobRef.GetProperties(ctx, nil)
if err != nil {
return nil, err
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to get blob properties: %v", err)
}
info := blobRef.Properties
size := info.ContentLength
if props.ContentLength == nil {
return nil, fmt.Errorf("failed to get ContentLength for path: %s", path)
}
size := *props.ContentLength
if offset >= size {
return io.NopCloser(bytes.NewReader(nil)), nil
}
resp, err := blobRef.GetRange(&azure.GetBlobRangeOptions{
Range: &azure.BlobRange{
Start: uint64(offset),
End: 0,
},
})
resp, err := blobRef.DownloadStream(ctx, &options)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
return resp, nil
return resp.Body, nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blobExists, err := blobRef.Exists()
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
props, err := blobRef.GetProperties(ctx, nil)
blobExists := true
if err != nil {
return nil, err
if !is404(err) {
return nil, err
}
blobExists = false
}
var size int64
if blobExists {
if append {
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
if props.ContentLength == nil {
return nil, fmt.Errorf("cannot append to blob because no ContentLength property was returned for: %s", blobName)
}
blobProperties := blobRef.Properties
size = blobProperties.ContentLength
size = *props.ContentLength
} else {
err = blobRef.Delete(nil)
if err != nil {
if _, err := blobRef.Delete(ctx, nil); err != nil {
return nil, err
}
}
@@ -213,57 +187,67 @@ func (d *driver) Writer(ctx context.Context, path string, append bool) (storaged
if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err = blobRef.PutAppendBlob(nil)
if err != nil {
if _, err = d.client.NewAppendBlobClient(blobName).Create(ctx, nil); err != nil {
return nil, err
}
}
return d.newWriter(path, size), nil
return d.newWriter(ctx, blobName, size), nil
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
// Check if the path is a blob
if ok, err := blobRef.Exists(); err != nil {
props, err := blobRef.GetProperties(ctx, nil)
if err != nil && !is404(err) {
return nil, err
} else if ok {
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
}
if err == nil {
var missing []string
if props.ContentLength == nil {
missing = append(missing, "ContentLength")
}
if props.LastModified == nil {
missing = append(missing, "LastModified")
}
blobProperties := blobRef.Properties
if len(missing) > 0 {
return nil, fmt.Errorf("required blob properties %s are missing for blob: %s", missing, blobName)
}
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: blobProperties.ContentLength,
ModTime: time.Time(blobProperties.LastModified),
Size: *props.ContentLength,
ModTime: *props.LastModified,
IsDir: false,
}}, nil
}
// Check if path is a virtual container
virtContainerPath := path
virtContainerPath := blobName
if !strings.HasSuffix(virtContainerPath, "/") {
virtContainerPath += "/"
}
containerRef := d.client.GetContainerReference(d.container)
blobs, err := containerRef.ListBlobs(azure.ListBlobsParameters{
Prefix: virtContainerPath,
MaxResults: 1,
maxResults := int32(1)
pager := d.client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
MaxResults: &maxResults,
Prefix: &virtContainerPath,
})
if err != nil {
return nil, err
}
if len(blobs.Blobs) > 0 {
// path is a virtual container
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}}, nil
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
return nil, err
}
if len(resp.Segment.BlobItems) > 0 {
// path is a virtual container
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}}, nil
}
}
// path is not a blob or virtual container
@@ -277,7 +261,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
path = ""
}
blobs, err := d.listBlobs(d.container, path)
blobs, err := d.listBlobs(ctx, path)
if err != nil {
return blobs, err
}
@@ -292,10 +276,12 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
srcBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(sourcePath)
sourceBlobURL := srcBlobRef.GetURL()
destBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(destPath)
err := destBlobRef.Copy(sourceBlobURL, nil)
sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil)
if err != nil {
return err
}
destBlobRef := d.client.NewBlockBlobClient(d.blobName(destPath))
_, err = destBlobRef.CopyFromURL(ctx, sourceBlobURL, nil)
if err != nil {
if is404(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
@@ -303,29 +289,30 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
return err
}
return srcBlobRef.Delete(nil)
_, err = d.client.NewBlobClient(d.blobName(sourcePath)).Delete(ctx, nil)
return err
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
ok, err := blobRef.DeleteIfExists(nil)
if err != nil {
blobRef := d.client.NewBlobClient(d.blobName(path))
_, err := blobRef.Delete(ctx, nil)
if err == nil {
// was a blob and deleted, return
return nil
} else if !is404(err) {
return err
}
if ok {
return nil // was a blob and deleted, return
}
// Not a blob, see if path is a virtual container with blobs
blobs, err := d.listBlobs(d.container, path)
blobs, err := d.listBlobs(ctx, path)
if err != nil {
return err
}
for _, b := range blobs {
blobRef = d.client.GetContainerReference(d.container).GetBlobReference(b)
if err = blobRef.Delete(nil); err != nil {
blobRef := d.client.NewBlobClient(d.blobName(b))
if _, err := blobRef.Delete(ctx, nil); err != nil {
return err
}
}
@@ -348,15 +335,9 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
expiresTime = t
}
}
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
return blobRef.GetSASURI(azure.BlobSASOptions{
BlobServiceSASPermissions: azure.BlobServiceSASPermissions{
Read: true,
},
SASOptions: azure.SASOptions{
Expiry: expiresTime,
},
})
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime)
}
// Walk traverses a filesystem defined within driver, starting
@@ -399,38 +380,51 @@ func directDescendants(blobs []string, prefix string) []string {
return keys
}
func (d *driver) listBlobs(container, virtPath string) ([]string, error) {
func (d *driver) listBlobs(ctx context.Context, virtPath string) ([]string, error) {
if virtPath != "" && !strings.HasSuffix(virtPath, "/") { // containerify the path
virtPath += "/"
}
out := []string{}
marker := ""
containerRef := d.client.GetContainerReference(d.container)
for {
resp, err := containerRef.ListBlobs(azure.ListBlobsParameters{
Marker: marker,
Prefix: virtPath,
})
if err != nil {
return out, err
}
// we will replace the root directory prefix before returning blob names
blobPrefix := d.blobName("")
for _, b := range resp.Blobs {
out = append(out, b.Name)
}
if len(resp.Blobs) == 0 || resp.NextMarker == "" {
break
}
marker = resp.NextMarker
// This is to cover for the cases when the rootDirectory of the driver is either "" or "/".
// In those cases, there is no root prefix to replace and we must actually add a "/" to all
// results in order to keep them as valid paths as recognized by storagedriver.PathRegexp
prefix := ""
if blobPrefix == "" {
prefix = "/"
}
out := []string{}
listPrefix := d.blobName(virtPath)
pager := d.client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &listPrefix,
})
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
return nil, err
}
for _, blob := range resp.Segment.BlobItems {
if blob.Name == nil {
return nil, fmt.Errorf("required blob property Name is missing while listing blobs under: %s", listPrefix)
}
name := *blob.Name
out = append(out, strings.Replace(name, blobPrefix, prefix, 1))
}
}
return out, nil
}
// blobName maps a storage-driver path to the blob name inside the
// container by prefixing the configured root directory and stripping any
// leading slash so the result is a valid Azure blob name.
func (d *driver) blobName(path string) string {
	root := strings.TrimRight(d.rootDirectory, "/")
	return strings.TrimLeft(root+path, "/")
}
func is404(err error) bool {
statusCodeErr, ok := err.(azure.AzureStorageServiceError)
return ok && statusCodeErr.StatusCode == http.StatusNotFound
return bloberror.HasCode(err, bloberror.BlobNotFound, bloberror.ContainerNotFound, bloberror.ResourceNotFound)
}
type writer struct {
@@ -443,15 +437,15 @@ type writer struct {
cancelled bool
}
func (d *driver) newWriter(path string, size int64) storagedriver.FileWriter {
func (d *driver) newWriter(ctx context.Context, path string, size int64) storagedriver.FileWriter {
return &writer{
driver: d,
path: path,
size: size,
bw: bufio.NewWriterSize(&blockWriter{
client: d.client,
container: d.container,
path: path,
ctx: ctx,
client: d.client,
path: path,
}, maxChunkSize),
}
}
@@ -482,15 +476,16 @@ func (w *writer) Close() error {
return w.bw.Flush()
}
func (w *writer) Cancel() error {
func (w *writer) Cancel(ctx context.Context) error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
blobRef := w.driver.client.GetContainerReference(w.driver.container).GetBlobReference(w.path)
return blobRef.Delete(nil)
blobRef := w.driver.client.NewBlobClient(w.path)
_, err := blobRef.Delete(ctx, nil)
return err
}
func (w *writer) Commit() error {
@@ -506,26 +501,18 @@ func (w *writer) Commit() error {
}
type blockWriter struct {
client azure.BlobStorageClient
container string
path string
// We construct transient blockWriter objects to encapsulate a write
// and need to keep the context passed in to the original FileWriter.Write
ctx context.Context
client *container.Client
path string
}
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
blobRef := bw.client.GetContainerReference(bw.container).GetBlobReference(bw.path)
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := blobRef.AppendBlock(p[offset:offset+chunkSize], nil)
if err != nil {
return n, err
}
n += chunkSize
blobRef := bw.client.NewAppendBlobClient(bw.path)
_, err := blobRef.AppendBlock(bw.ctx, streaming.NopCloser(bytes.NewReader(p)), nil)
if err != nil {
return 0, err
}
return n, nil
return len(p), nil
}

View File

@@ -0,0 +1,152 @@
package azure
import (
"context"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
)
const (
UDCGracePeriod = 30.0 * time.Minute
UDCExpiryTime = 48.0 * time.Hour
)
// signer abstracts the specifics of a blob SAS and is specialized
// for the different authentication credentials
type signer interface {
	// Sign produces SAS query parameters for the given signature values.
	Sign(context.Context, *sas.BlobSignatureValues) (sas.QueryParameters, error)
}

// sharedKeySigner signs SAS values directly with the storage account's
// shared key credential.
type sharedKeySigner struct {
	cred *azblob.SharedKeyCredential
}

// clientTokenSigner signs SAS values with a user delegation credential
// obtained from a token credential; the credential is cached and reissued
// near expiry (see refreshUDC).
type clientTokenSigner struct {
	client *azblob.Client
	cred   azcore.TokenCredential
	// udcMutex guards udc and udcExpiry below.
	udcMutex  sync.Mutex
	udc       *service.UserDelegationCredential
	udcExpiry time.Time
}

// azureClient abstracts signing blob urls for a container since the
// azure apis have completely different underlying authentication apis
type azureClient struct {
	container string
	client    *azblob.Client
	signer    signer
}
// newAzureClient builds an azureClient from the driver parameters,
// choosing the authentication scheme: the storage account's shared key
// when one is configured, an AAD client-secret credential when requested,
// and the default Azure credential chain otherwise.
func newAzureClient(params *Parameters) (*azureClient, error) {
	// Shared-key authentication takes precedence when a key is supplied.
	if params.AccountKey != "" {
		cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey)
		if err != nil {
			return nil, err
		}
		client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, nil)
		if err != nil {
			return nil, err
		}
		return &azureClient{
			container: params.Container,
			client:    client,
			signer:    &sharedKeySigner{cred: cred},
		}, nil
	}

	// Token-based authentication: explicit client secret or the default
	// credential chain.
	var (
		tokenCred azcore.TokenCredential
		err       error
	)
	if params.Credentials.Type == "client_secret" {
		creds := &params.Credentials
		tokenCred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil)
	} else {
		tokenCred, err = azidentity.NewDefaultAzureCredential(nil)
	}
	if err != nil {
		return nil, err
	}

	client, err := azblob.NewClient(params.ServiceURL, tokenCred, nil)
	if err != nil {
		return nil, err
	}
	return &azureClient{
		container: params.Container,
		client:    client,
		signer:    &clientTokenSigner{client: client, cred: tokenCred},
	}, nil
}
// ContainerClient returns a client scoped to the configured container.
func (a *azureClient) ContainerClient() *container.Client {
	svc := a.client.ServiceClient()
	return svc.NewContainerClient(a.container)
}
// SignBlobURL appends a read-only SAS, valid until expires, to blobURL
// and returns the signed URL.
func (a *azureClient) SignBlobURL(ctx context.Context, blobURL string, expires time.Time) (string, error) {
	parts, err := sas.ParseURL(blobURL)
	if err != nil {
		return "", err
	}
	readPerm := sas.BlobPermissions{Read: true}
	// Backdate the start time slightly to tolerate clock skew between
	// this host and the storage service.
	values := sas.BlobSignatureValues{
		Protocol:      sas.ProtocolHTTPS,
		StartTime:     time.Now().UTC().Add(-10 * time.Second),
		ExpiryTime:    expires,
		Permissions:   readPerm.String(),
		ContainerName: parts.ContainerName,
		BlobName:      parts.BlobName,
	}
	qp, err := a.signer.Sign(ctx, &values)
	if err != nil {
		return "", err
	}
	parts.SAS = qp
	return parts.String(), nil
}
// Sign implements signer by signing the SAS values directly with the
// storage account's shared key; ctx is unused for this scheme.
func (s *sharedKeySigner) Sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) {
	return signatureValues.SignWithSharedKey(s.cred)
}
// refreshUDC returns the cached user delegation credential, reissuing it
// first when it is missing or due to expire within UDCGracePeriod.
func (s *clientTokenSigner) refreshUDC(ctx context.Context) (*service.UserDelegationCredential, error) {
	s.udcMutex.Lock()
	defer s.udcMutex.Unlock()

	now := time.Now().UTC()
	if s.udc != nil && s.udcExpiry.Sub(now) >= UDCGracePeriod {
		// Cached credential is still comfortably valid.
		return s.udc, nil
	}

	// Reissue the user delegation credential; backdate the window start
	// to absorb clock skew.
	startTime := now.Add(-10 * time.Second)
	expiryTime := startTime.Add(UDCExpiryTime)
	info := service.KeyInfo{
		Start:  to.Ptr(startTime.UTC().Format(sas.TimeFormat)),
		Expiry: to.Ptr(expiryTime.UTC().Format(sas.TimeFormat)),
	}
	udc, err := s.client.ServiceClient().GetUserDelegationCredential(ctx, info, nil)
	if err != nil {
		return nil, err
	}
	s.udc = udc
	s.udcExpiry = expiryTime
	return s.udc, nil
}
// Sign implements signer by signing with a user delegation credential,
// refreshing the cached credential first when needed.
func (s *clientTokenSigner) Sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) {
	udc, err := s.refreshUDC(ctx)
	if err != nil {
		return sas.QueryParameters{}, err
	}
	return signatureValues.SignWithUserDelegation(udc)
}

View File

@@ -12,10 +12,11 @@ import (
)
const (
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
envRootDirectory = "AZURE_ROOT_DIRECTORY"
)
// Hook up gocheck into the "go test" runner.
@@ -23,32 +24,42 @@ func Test(t *testing.T) { TestingT(t) }
func init() {
var (
accountName string
accountKey string
container string
realm string
accountName string
accountKey string
container string
realm string
rootDirectory string
)
config := []struct {
env string
value *string
env string
value *string
missingOk bool
}{
{envAccountName, &accountName},
{envAccountKey, &accountKey},
{envContainer, &container},
{envRealm, &realm},
{envAccountName, &accountName, false},
{envAccountKey, &accountKey, false},
{envContainer, &container, false},
{envRealm, &realm, false},
{envRootDirectory, &rootDirectory, true},
}
missing := []string{}
for _, v := range config {
*v.value = os.Getenv(v.env)
if *v.value == "" {
if *v.value == "" && !v.missingOk {
missing = append(missing, v.env)
}
}
azureDriverConstructor := func() (storagedriver.StorageDriver, error) {
return New(accountName, accountKey, container, realm)
params := Parameters{
Container: container,
AccountName: accountName,
AccountKey: accountKey,
Realm: realm,
RootDirectory: rootDirectory,
}
return New(&params)
}
// Skip Azure storage driver tests if environment variable parameters are not provided
@@ -61,3 +72,44 @@ func init() {
testsuites.RegisterSuite(azureDriverConstructor, skipCheck)
}
// TestParamParsing verifies NewParameters: required parameters are
// enforced, and realm/service-URL defaults are derived correctly for
// shared-key, default-credential, and client-secret configurations.
func TestParamParsing(t *testing.T) {
	// Each of these maps is missing a required parameter and must fail.
	expectErrors := []map[string]interface{}{
		{},
		{"accountname": "acc1"},
	}
	for _, parameters := range expectErrors {
		if _, err := NewParameters(parameters); err == nil {
			t.Fatalf("Expected an error for parameter set: %v", parameters)
		}
	}

	cases := []struct {
		input    map[string]interface{}
		expected Parameters
	}{
		{
			input: map[string]interface{}{"accountname": "acc1", "accountkey": "k1", "container": "c1"},
			expected: Parameters{
				Container: "c1", AccountName: "acc1", AccountKey: "k1",
				Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
			},
		},
		{
			input: map[string]interface{}{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "default"}},
			expected: Parameters{
				Container: "c1", AccountName: "acc1", Credentials: Credentials{Type: "default"},
				Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
			},
		},
		{
			input: map[string]interface{}{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "client_secret", "clientid": "c1", "tenantid": "t1", "secret": "s1"}},
			expected: Parameters{
				Container: "c1", AccountName: "acc1",
				Credentials: Credentials{Type: "client_secret", ClientID: "c1", TenantID: "t1", Secret: "s1"},
				Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
			},
		},
	}
	for _, c := range cases {
		actual, err := NewParameters(c.input)
		if err != nil {
			t.Fatalf("Failed to parse: %v", c.input)
		}
		if *actual != c.expected {
			t.Fatalf("Expected: %v != %v", *actual, c.expected)
		}
	}
}

View File

@@ -0,0 +1,49 @@
package azure
import (
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
)
const (
	// defaultRealm is the DNS suffix of the public Azure cloud, used when
	// no realm parameter is supplied.
	defaultRealm = "core.windows.net"
)
// Credentials describes the Azure AD credentials used for token-based
// authentication. Type selects the scheme (e.g. "client_secret"); the
// remaining fields apply to the client-secret flow.
type Credentials struct {
	Type     string `mapstructure:"type"`
	ClientID string `mapstructure:"clientid"`
	TenantID string `mapstructure:"tenantid"`
	Secret   string `mapstructure:"secret"`
}
// Parameters holds the configuration for the Azure storage driver,
// decoded from the registry's parameter map. AccountName and Container
// are required; Realm and ServiceURL are defaulted when empty.
type Parameters struct {
	Container        string      `mapstructure:"container"`
	AccountName      string      `mapstructure:"accountname"`
	AccountKey       string      `mapstructure:"accountkey"`
	Credentials      Credentials `mapstructure:"credentials"`
	ConnectionString string      `mapstructure:"connectionstring"`
	Realm            string      `mapstructure:"realm"`
	RootDirectory    string      `mapstructure:"rootdirectory"`
	ServiceURL       string      `mapstructure:"serviceurl"`
}
// NewParameters decodes the driver parameter map into a Parameters value,
// applying the default realm and deriving the blob service URL when one
// is not configured explicitly. It returns an error when a required
// parameter (accountname, container) is missing.
func NewParameters(parameters map[string]interface{}) (*Parameters, error) {
	params := Parameters{
		Realm: defaultRealm,
	}
	if err := mapstructure.Decode(parameters, &params); err != nil {
		return nil, err
	}
	if params.AccountName == "" {
		return nil, errors.New("no accountname parameter provided")
	}
	if params.Container == "" {
		// Fix: message previously read "provider" instead of "provided".
		return nil, errors.New("no container parameter provided")
	}
	if params.ServiceURL == "" {
		// Default to the account's blob endpoint in the configured realm.
		params.ServiceURL = fmt.Sprintf("https://%s.blob.%s", params.AccountName, params.Realm)
	}
	return &params, nil
}

View File

@@ -139,7 +139,7 @@ func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte
defer writer.Close()
_, err = io.Copy(writer, bytes.NewReader(contents))
if err != nil {
writer.Cancel()
writer.Cancel(ctx)
return err
}
return writer.Commit()
@@ -387,7 +387,7 @@ func (fw *fileWriter) Close() error {
return nil
}
func (fw *fileWriter) Cancel() error {
func (fw *fileWriter) Cancel(ctx context.Context) error {
if fw.closed {
return fmt.Errorf("already closed")
}

View File

@@ -293,7 +293,7 @@ func (w *writer) Close() error {
return nil
}
func (w *writer) Cancel() error {
func (w *writer) Cancel(ctx context.Context) error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {

View File

@@ -1425,7 +1425,7 @@ func (w *writer) Close() error {
return w.flushPart()
}
func (w *writer) Cancel() error {
func (w *writer) Cancel(ctx context.Context) error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {

View File

@@ -103,7 +103,7 @@ type FileWriter interface {
Size() int64
// Cancel removes any written content from this FileWriter.
Cancel() error
Cancel(context.Context) error
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and

View File

@@ -850,14 +850,14 @@ func (w *writer) Close() error {
return nil
}
func (w *writer) Cancel() error {
func (w *writer) Cancel(ctx context.Context) error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.Delete(context.Background(), w.path)
return w.driver.Delete(ctx, w.path)
}
func (w *writer) Commit() error {

View File

@@ -61,9 +61,9 @@ func (tfw *testFileWriter) Close() error {
return tfw.FileWriter.Close()
}
func (tfw *testFileWriter) Cancel() error {
func (tfw *testFileWriter) Cancel(ctx context.Context) error {
tfw.Write(nil)
return tfw.FileWriter.Cancel()
return tfw.FileWriter.Cancel(ctx)
}
func (tfw *testFileWriter) Commit() error {