Update Azure SDK and support additional authentication schemes

Microsoft has updated the golang Azure SDK significantly.  Update the
azure storage driver to use the new SDK.  Add support for client
secret and MSI authentication schemes in addition to shared key
authentication.

Implement rootDirectory support for the azure storage driver to mirror
the S3 driver.

Signed-off-by: Kirat Singh <kirat.singh@beacon.io>

Co-authored-by: Cory Snider <corhere@gmail.com>
This commit is contained in:
Kirat Singh
2020-02-21 03:58:17 +00:00
parent e5d5810851
commit ba4a6bbe02
365 changed files with 44060 additions and 21016 deletions

View File

@@ -8,7 +8,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
@@ -16,22 +15,22 @@ import (
"github.com/distribution/distribution/v3/registry/storage/driver/base"
"github.com/distribution/distribution/v3/registry/storage/driver/factory"
azure "github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
)
const driverName = "azure"
const (
paramAccountName = "accountname"
paramAccountKey = "accountkey"
paramContainer = "container"
paramRealm = "realm"
maxChunkSize = 4 * 1024 * 1024
maxChunkSize = 4 * 1024 * 1024
)
type driver struct {
client azure.BlobStorageClient
container string
azClient *azureClient
client *container.Client
rootDirectory string
}
type baseEmbed struct{ base.Base }
@@ -47,53 +46,24 @@ func init() {
type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
// FromParameters constructs a new Driver with a given parameters map.
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
accountName, ok := parameters[paramAccountName]
if !ok || fmt.Sprint(accountName) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramAccountName)
}
accountKey, ok := parameters[paramAccountKey]
if !ok || fmt.Sprint(accountKey) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramAccountKey)
}
container, ok := parameters[paramContainer]
if !ok || fmt.Sprint(container) == "" {
return nil, fmt.Errorf("no %s parameter provided", paramContainer)
}
realm, ok := parameters[paramRealm]
if !ok || fmt.Sprint(realm) == "" {
realm = azure.DefaultBaseURL
}
return New(fmt.Sprint(accountName), fmt.Sprint(accountKey), fmt.Sprint(container), fmt.Sprint(realm))
}
// New constructs a new Driver with the given Azure Storage Account credentials
func New(accountName, accountKey, container, realm string) (*Driver, error) {
api, err := azure.NewClient(accountName, accountKey, realm, azure.DefaultAPIVersion, true)
params, err := NewParameters(parameters)
if err != nil {
return nil, err
}
return New(params)
}
blobClient := api.GetBlobService()
// Create registry container
containerRef := blobClient.GetContainerReference(container)
if _, err = containerRef.CreateIfNotExists(nil); err != nil {
// New constructs a new Driver from parameters
func New(params *Parameters) (*Driver, error) {
azClient, err := newAzureClient(params)
if err != nil {
return nil, err
}
client := azClient.ContainerClient()
d := &driver{
client: blobClient,
container: container,
}
azClient: azClient,
client: client,
rootDirectory: params.RootDirectory}
return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil
}
@@ -104,17 +74,16 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blob, err := blobRef.Get(nil)
downloadResponse, err := d.client.NewBlobClient(d.blobName(path)).DownloadStream(ctx, nil)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
defer blob.Close()
return io.ReadAll(blob)
body := downloadResponse.Body
defer body.Close()
return io.ReadAll(body)
}
// PutContent stores the []byte content at a location designated by "path".
@@ -137,75 +106,80 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
// losing the existing data while migrating it to BlockBlob type. However,
// expectation is the clients pushing will be retrying when they get an error
// response.
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
err := blobRef.GetProperties(nil)
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
props, err := blobRef.GetProperties(ctx, nil)
if err != nil && !is404(err) {
return fmt.Errorf("failed to get blob properties: %v", err)
}
if err == nil && blobRef.Properties.BlobType != azure.BlobTypeBlock {
if err := blobRef.Delete(nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%s): %v", blobRef.Properties.BlobType, err)
if err == nil && props.BlobType != nil && *props.BlobType != blob.BlobTypeBlockBlob {
if _, err := blobRef.Delete(ctx, nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%v): %v", *props.BlobType, err)
}
}
r := bytes.NewReader(contents)
// reset properties to empty before doing overwrite
blobRef.Properties = azure.BlobProperties{}
return blobRef.CreateBlockBlobFromReader(r, nil)
_, err = d.client.NewBlockBlobClient(blobName).UploadBuffer(ctx, contents, nil)
return err
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
if ok, err := blobRef.Exists(); err != nil {
return nil, err
} else if !ok {
return nil, storagedriver.PathNotFoundError{Path: path}
blobRef := d.client.NewBlobClient(d.blobName(path))
options := blob.DownloadStreamOptions{
Range: blob.HTTPRange{
Offset: offset,
},
}
err := blobRef.GetProperties(nil)
props, err := blobRef.GetProperties(ctx, nil)
if err != nil {
return nil, err
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to get blob properties: %v", err)
}
info := blobRef.Properties
size := info.ContentLength
if props.ContentLength == nil {
return nil, fmt.Errorf("failed to get ContentLength for path: %s", path)
}
size := *props.ContentLength
if offset >= size {
return io.NopCloser(bytes.NewReader(nil)), nil
}
resp, err := blobRef.GetRange(&azure.GetBlobRangeOptions{
Range: &azure.BlobRange{
Start: uint64(offset),
End: 0,
},
})
resp, err := blobRef.DownloadStream(ctx, &options)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
return resp, nil
return resp.Body, nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blobExists, err := blobRef.Exists()
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
props, err := blobRef.GetProperties(ctx, nil)
blobExists := true
if err != nil {
return nil, err
if !is404(err) {
return nil, err
}
blobExists = false
}
var size int64
if blobExists {
if append {
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
if props.ContentLength == nil {
return nil, fmt.Errorf("cannot append to blob because no ContentLength property was returned for: %s", blobName)
}
blobProperties := blobRef.Properties
size = blobProperties.ContentLength
size = *props.ContentLength
} else {
err = blobRef.Delete(nil)
if err != nil {
if _, err := blobRef.Delete(ctx, nil); err != nil {
return nil, err
}
}
@@ -213,57 +187,67 @@ func (d *driver) Writer(ctx context.Context, path string, append bool) (storaged
if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err = blobRef.PutAppendBlob(nil)
if err != nil {
if _, err = d.client.NewAppendBlobClient(blobName).Create(ctx, nil); err != nil {
return nil, err
}
}
return d.newWriter(path, size), nil
return d.newWriter(ctx, blobName, size), nil
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
// Check if the path is a blob
if ok, err := blobRef.Exists(); err != nil {
props, err := blobRef.GetProperties(ctx, nil)
if err != nil && !is404(err) {
return nil, err
} else if ok {
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
}
if err == nil {
var missing []string
if props.ContentLength == nil {
missing = append(missing, "ContentLength")
}
if props.LastModified == nil {
missing = append(missing, "LastModified")
}
blobProperties := blobRef.Properties
if len(missing) > 0 {
return nil, fmt.Errorf("required blob properties %s are missing for blob: %s", missing, blobName)
}
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: blobProperties.ContentLength,
ModTime: time.Time(blobProperties.LastModified),
Size: *props.ContentLength,
ModTime: *props.LastModified,
IsDir: false,
}}, nil
}
// Check if path is a virtual container
virtContainerPath := path
virtContainerPath := blobName
if !strings.HasSuffix(virtContainerPath, "/") {
virtContainerPath += "/"
}
containerRef := d.client.GetContainerReference(d.container)
blobs, err := containerRef.ListBlobs(azure.ListBlobsParameters{
Prefix: virtContainerPath,
MaxResults: 1,
maxResults := int32(1)
pager := d.client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
MaxResults: &maxResults,
Prefix: &virtContainerPath,
})
if err != nil {
return nil, err
}
if len(blobs.Blobs) > 0 {
// path is a virtual container
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}}, nil
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
return nil, err
}
if len(resp.Segment.BlobItems) > 0 {
// path is a virtual container
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}}, nil
}
}
// path is not a blob or virtual container
@@ -277,7 +261,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
path = ""
}
blobs, err := d.listBlobs(d.container, path)
blobs, err := d.listBlobs(ctx, path)
if err != nil {
return blobs, err
}
@@ -292,10 +276,12 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
srcBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(sourcePath)
sourceBlobURL := srcBlobRef.GetURL()
destBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(destPath)
err := destBlobRef.Copy(sourceBlobURL, nil)
sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil)
if err != nil {
return err
}
destBlobRef := d.client.NewBlockBlobClient(d.blobName(destPath))
_, err = destBlobRef.CopyFromURL(ctx, sourceBlobURL, nil)
if err != nil {
if is404(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
@@ -303,29 +289,30 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
return err
}
return srcBlobRef.Delete(nil)
_, err = d.client.NewBlobClient(d.blobName(sourcePath)).Delete(ctx, nil)
return err
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
ok, err := blobRef.DeleteIfExists(nil)
if err != nil {
blobRef := d.client.NewBlobClient(d.blobName(path))
_, err := blobRef.Delete(ctx, nil)
if err == nil {
// was a blob and deleted, return
return nil
} else if !is404(err) {
return err
}
if ok {
return nil // was a blob and deleted, return
}
// Not a blob, see if path is a virtual container with blobs
blobs, err := d.listBlobs(d.container, path)
blobs, err := d.listBlobs(ctx, path)
if err != nil {
return err
}
for _, b := range blobs {
blobRef = d.client.GetContainerReference(d.container).GetBlobReference(b)
if err = blobRef.Delete(nil); err != nil {
blobRef := d.client.NewBlobClient(d.blobName(b))
if _, err := blobRef.Delete(ctx, nil); err != nil {
return err
}
}
@@ -348,15 +335,9 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
expiresTime = t
}
}
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
return blobRef.GetSASURI(azure.BlobSASOptions{
BlobServiceSASPermissions: azure.BlobServiceSASPermissions{
Read: true,
},
SASOptions: azure.SASOptions{
Expiry: expiresTime,
},
})
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime)
}
// Walk traverses a filesystem defined within driver, starting
@@ -399,38 +380,51 @@ func directDescendants(blobs []string, prefix string) []string {
return keys
}
func (d *driver) listBlobs(container, virtPath string) ([]string, error) {
func (d *driver) listBlobs(ctx context.Context, virtPath string) ([]string, error) {
if virtPath != "" && !strings.HasSuffix(virtPath, "/") { // containerify the path
virtPath += "/"
}
out := []string{}
marker := ""
containerRef := d.client.GetContainerReference(d.container)
for {
resp, err := containerRef.ListBlobs(azure.ListBlobsParameters{
Marker: marker,
Prefix: virtPath,
})
if err != nil {
return out, err
}
// we will replace the root directory prefix before returning blob names
blobPrefix := d.blobName("")
for _, b := range resp.Blobs {
out = append(out, b.Name)
}
if len(resp.Blobs) == 0 || resp.NextMarker == "" {
break
}
marker = resp.NextMarker
// This is to cover for the cases when the rootDirectory of the driver is either "" or "/".
// In those cases, there is no root prefix to replace and we must actually add a "/" to all
// results in order to keep them as valid paths as recognized by storagedriver.PathRegexp
prefix := ""
if blobPrefix == "" {
prefix = "/"
}
out := []string{}
listPrefix := d.blobName(virtPath)
pager := d.client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &listPrefix,
})
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
return nil, err
}
for _, blob := range resp.Segment.BlobItems {
if blob.Name == nil {
return nil, fmt.Errorf("required blob property Name is missing while listing blobs under: %s", listPrefix)
}
name := *blob.Name
out = append(out, strings.Replace(name, blobPrefix, prefix, 1))
}
}
return out, nil
}
func (d *driver) blobName(path string) string {
return strings.TrimLeft(strings.TrimRight(d.rootDirectory, "/")+path, "/")
}
func is404(err error) bool {
statusCodeErr, ok := err.(azure.AzureStorageServiceError)
return ok && statusCodeErr.StatusCode == http.StatusNotFound
return bloberror.HasCode(err, bloberror.BlobNotFound, bloberror.ContainerNotFound, bloberror.ResourceNotFound)
}
type writer struct {
@@ -443,15 +437,15 @@ type writer struct {
cancelled bool
}
func (d *driver) newWriter(path string, size int64) storagedriver.FileWriter {
func (d *driver) newWriter(ctx context.Context, path string, size int64) storagedriver.FileWriter {
return &writer{
driver: d,
path: path,
size: size,
bw: bufio.NewWriterSize(&blockWriter{
client: d.client,
container: d.container,
path: path,
ctx: ctx,
client: d.client,
path: path,
}, maxChunkSize),
}
}
@@ -482,15 +476,16 @@ func (w *writer) Close() error {
return w.bw.Flush()
}
func (w *writer) Cancel() error {
func (w *writer) Cancel(ctx context.Context) error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
blobRef := w.driver.client.GetContainerReference(w.driver.container).GetBlobReference(w.path)
return blobRef.Delete(nil)
blobRef := w.driver.client.NewBlobClient(w.path)
_, err := blobRef.Delete(ctx, nil)
return err
}
func (w *writer) Commit() error {
@@ -506,26 +501,18 @@ func (w *writer) Commit() error {
}
type blockWriter struct {
client azure.BlobStorageClient
container string
path string
// We construct transient blockWriter objects to encapsulate a write
// and need to keep the context passed in to the original FileWriter.Write
ctx context.Context
client *container.Client
path string
}
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
blobRef := bw.client.GetContainerReference(bw.container).GetBlobReference(bw.path)
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := blobRef.AppendBlock(p[offset:offset+chunkSize], nil)
if err != nil {
return n, err
}
n += chunkSize
blobRef := bw.client.NewAppendBlobClient(bw.path)
_, err := blobRef.AppendBlock(bw.ctx, streaming.NopCloser(bytes.NewReader(p)), nil)
if err != nil {
return 0, err
}
return n, nil
return len(p), nil
}

View File

@@ -0,0 +1,152 @@
package azure
import (
"context"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
)
const (
UDCGracePeriod = 30.0 * time.Minute
UDCExpiryTime = 48.0 * time.Hour
)
// signer abstracts the specifics of a blob SAS and is specialized
// for the different authentication credentials
type signer interface {
Sign(context.Context, *sas.BlobSignatureValues) (sas.QueryParameters, error)
}
type sharedKeySigner struct {
cred *azblob.SharedKeyCredential
}
type clientTokenSigner struct {
client *azblob.Client
cred azcore.TokenCredential
udcMutex sync.Mutex
udc *service.UserDelegationCredential
udcExpiry time.Time
}
// azureClient abstracts signing blob urls for a container since the
// azure apis have completely different underlying authentication apis
type azureClient struct {
container string
client *azblob.Client
signer signer
}
func newAzureClient(params *Parameters) (*azureClient, error) {
if params.AccountKey != "" {
cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey)
if err != nil {
return nil, err
}
client, err := azblob.NewClientWithSharedKeyCredential(params.ServiceURL, cred, nil)
if err != nil {
return nil, err
}
signer := &sharedKeySigner{
cred: cred,
}
return &azureClient{
container: params.Container,
client: client,
signer: signer,
}, nil
}
var cred azcore.TokenCredential
var err error
if params.Credentials.Type == "client_secret" {
creds := &params.Credentials
if cred, err = azidentity.NewClientSecretCredential(creds.TenantID, creds.ClientID, creds.Secret, nil); err != nil {
return nil, err
}
} else if cred, err = azidentity.NewDefaultAzureCredential(nil); err != nil {
return nil, err
}
client, err := azblob.NewClient(params.ServiceURL, cred, nil)
if err != nil {
return nil, err
}
signer := &clientTokenSigner{
client: client,
cred: cred,
}
return &azureClient{
container: params.Container,
client: client,
signer: signer,
}, nil
}
func (a *azureClient) ContainerClient() *container.Client {
return a.client.ServiceClient().NewContainerClient(a.container)
}
func (a *azureClient) SignBlobURL(ctx context.Context, blobURL string, expires time.Time) (string, error) {
urlParts, err := sas.ParseURL(blobURL)
if err != nil {
return "", err
}
perms := sas.BlobPermissions{Read: true}
signatureValues := sas.BlobSignatureValues{
Protocol: sas.ProtocolHTTPS,
StartTime: time.Now().UTC().Add(-10 * time.Second),
ExpiryTime: expires,
Permissions: perms.String(),
ContainerName: urlParts.ContainerName,
BlobName: urlParts.BlobName,
}
urlParts.SAS, err = a.signer.Sign(ctx, &signatureValues)
if err != nil {
return "", err
}
return urlParts.String(), nil
}
func (s *sharedKeySigner) Sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) {
return signatureValues.SignWithSharedKey(s.cred)
}
func (s *clientTokenSigner) refreshUDC(ctx context.Context) (*service.UserDelegationCredential, error) {
s.udcMutex.Lock()
defer s.udcMutex.Unlock()
now := time.Now().UTC()
if s.udc == nil || s.udcExpiry.Sub(now) < UDCGracePeriod {
// reissue user delegation credential
startTime := now.Add(-10 * time.Second)
expiryTime := startTime.Add(UDCExpiryTime)
info := service.KeyInfo{
Start: to.Ptr(startTime.UTC().Format(sas.TimeFormat)),
Expiry: to.Ptr(expiryTime.UTC().Format(sas.TimeFormat)),
}
udc, err := s.client.ServiceClient().GetUserDelegationCredential(ctx, info, nil)
if err != nil {
return nil, err
}
s.udc = udc
s.udcExpiry = expiryTime
}
return s.udc, nil
}
func (s *clientTokenSigner) Sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) {
udc, err := s.refreshUDC(ctx)
if err != nil {
return sas.QueryParameters{}, err
}
return signatureValues.SignWithUserDelegation(udc)
}

View File

@@ -12,10 +12,11 @@ import (
)
const (
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
envRootDirectory = "AZURE_ROOT_DIRECTORY"
)
// Hook up gocheck into the "go test" runner.
@@ -23,32 +24,42 @@ func Test(t *testing.T) { TestingT(t) }
func init() {
var (
accountName string
accountKey string
container string
realm string
accountName string
accountKey string
container string
realm string
rootDirectory string
)
config := []struct {
env string
value *string
env string
value *string
missingOk bool
}{
{envAccountName, &accountName},
{envAccountKey, &accountKey},
{envContainer, &container},
{envRealm, &realm},
{envAccountName, &accountName, false},
{envAccountKey, &accountKey, false},
{envContainer, &container, false},
{envRealm, &realm, false},
{envRootDirectory, &rootDirectory, true},
}
missing := []string{}
for _, v := range config {
*v.value = os.Getenv(v.env)
if *v.value == "" {
if *v.value == "" && !v.missingOk {
missing = append(missing, v.env)
}
}
azureDriverConstructor := func() (storagedriver.StorageDriver, error) {
return New(accountName, accountKey, container, realm)
params := Parameters{
Container: container,
AccountName: accountName,
AccountKey: accountKey,
Realm: realm,
RootDirectory: rootDirectory,
}
return New(&params)
}
// Skip Azure storage driver tests if environment variable parameters are not provided
@@ -61,3 +72,44 @@ func init() {
testsuites.RegisterSuite(azureDriverConstructor, skipCheck)
}
func TestParamParsing(t *testing.T) {
expectErrors := []map[string]interface{}{
{},
{"accountname": "acc1"},
}
for _, parameters := range expectErrors {
if _, err := NewParameters(parameters); err == nil {
t.Fatalf("Expected an error for parameter set: %v", parameters)
}
}
input := []map[string]interface{}{
{"accountname": "acc1", "accountkey": "k1", "container": "c1"},
{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "default"}},
{"accountname": "acc1", "container": "c1", "credentials": map[string]interface{}{"type": "client_secret", "clientid": "c1", "tenantid": "t1", "secret": "s1"}},
}
expecteds := []Parameters{
{
Container: "c1", AccountName: "acc1", AccountKey: "k1",
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
},
{
Container: "c1", AccountName: "acc1", Credentials: Credentials{Type: "default"},
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
},
{
Container: "c1", AccountName: "acc1",
Credentials: Credentials{Type: "client_secret", ClientID: "c1", TenantID: "t1", Secret: "s1"},
Realm: "core.windows.net", ServiceURL: "https://acc1.blob.core.windows.net",
},
}
for i, expected := range expecteds {
actual, err := NewParameters(input[i])
if err != nil {
t.Fatalf("Failed to parse: %v", input[i])
}
if *actual != expected {
t.Fatalf("Expected: %v != %v", *actual, expected)
}
}
}

View File

@@ -0,0 +1,49 @@
package azure
import (
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
)
const (
defaultRealm = "core.windows.net"
)
type Credentials struct {
Type string `mapstructure:"type"`
ClientID string `mapstructure:"clientid"`
TenantID string `mapstructure:"tenantid"`
Secret string `mapstructure:"secret"`
}
type Parameters struct {
Container string `mapstructure:"container"`
AccountName string `mapstructure:"accountname"`
AccountKey string `mapstructure:"accountkey"`
Credentials Credentials `mapstructure:"credentials"`
ConnectionString string `mapstructure:"connectionstring"`
Realm string `mapstructure:"realm"`
RootDirectory string `mapstructure:"rootdirectory"`
ServiceURL string `mapstructure:"serviceurl"`
}
func NewParameters(parameters map[string]interface{}) (*Parameters, error) {
params := Parameters{
Realm: defaultRealm,
}
if err := mapstructure.Decode(parameters, &params); err != nil {
return nil, err
}
if params.AccountName == "" {
return nil, errors.New("no accountname parameter provided")
}
if params.Container == "" {
return nil, errors.New("no container parameter provider")
}
if params.ServiceURL == "" {
params.ServiceURL = fmt.Sprintf("https://%s.blob.%s", params.AccountName, params.Realm)
}
return &params, nil
}