From 723dd7c670d01de5bf8016e31c61a3508a0ac5b7 Mon Sep 17 00:00:00 2001
From: Gaurav Mehta <gaurav@rancher.com>
Date: Mon, 5 Oct 2020 14:15:57 +1100
Subject: [PATCH] Initial commit for adding ecr credential plugin

---
 cluster/validation.go          | 21 ++++++++
 docker/docker.go               | 45 ++++++++++++++---
 go.mod                         |  1 +
 go.sum                         |  6 +++
 types/rke_types.go             |  2 +
 types/zz_generated_deepcopy.go | 11 ++++-
 util/ecr.go                    | 89 ++++++++++++++++++++++++++++++++++
 7 files changed, 166 insertions(+), 9 deletions(-)
 create mode 100644 util/ecr.go

diff --git a/cluster/validation.go b/cluster/validation.go
index 6f07e97b..90a13b25 100644
--- a/cluster/validation.go
+++ b/cluster/validation.go
@@ -54,6 +54,11 @@ func (c *Cluster) ValidateCluster(ctx context.Context) error {
 		return err
 	}
 
+	// validate registry credential plugin
+	if err := validateRegistryAuthPlugin(c); err != nil {
+		return err
+	}
+
 	// validate services options
 	return validateServicesOptions(c)
 }
@@ -605,3 +610,19 @@ func validateCRIDockerdOption(c *Cluster) error {
 	}
 	return nil
 }
+
+func validateRegistryAuthPlugin(c *Cluster) error {
+	for _, pr := range c.PrivateRegistriesMap {
+		if len(pr.CredentialPlugin) != 0 {
+			if credPluginType, ok := pr.CredentialPlugin["type"]; ok {
+				switch credPluginType {
+				case "ecr":
+					logrus.Debugf("Plugin type %s is valid", credPluginType)
+				default:
+					return fmt.Errorf("invalid registry plugin helper provided for %s", pr.URL)
+				}
+			}
+		}
+	}
+	return nil
+}
diff --git a/docker/docker.go b/docker/docker.go
index 67671b9f..fcd2f967 100644
--- a/docker/docker.go
+++ b/docker/docker.go
@@ -13,6 +13,8 @@ import (
 	"strings"
 	"time"
 
+	"github.com/rancher/rke/util"
+
 	"github.com/coreos/go-semver/semver"
 	ref "github.com/docker/distribution/reference"
 	"github.com/docker/docker/api/types"
@@ -41,7 +43,8 @@ const (
 )
 
 type dockerConfig struct {
-	Auths map[string]authConfig `json:"auths,omitempty"`
+	Auths       map[string]authConfig `json:"auths,omitempty"`
+	CredHelpers map[string]string     `json:"credHelpers,omitempty"`
 }
 
 type authConfig types.AuthConfig
@@ -667,10 +670,28 @@ func tryRegistryAuth(pr v3.PrivateRegistry) types.RequestPrivilegeFunc {
 }
 
 func getRegistryAuth(pr v3.PrivateRegistry) (string, error) {
-	authConfig := types.AuthConfig{
-		Username: pr.User,
-		Password: pr.Password,
+	var authConfig types.AuthConfig
+	var err error
+	if len(pr.User) == 0 && len(pr.Password) == 0 && len(pr.CredentialPlugin) != 0 {
+		if regType, ok := pr.CredentialPlugin["type"]; ok {
+			switch regType {
+			case "ecr":
+				// generate ecr authConfig
+				authConfig, err = util.ECRCredentialPlugin(pr.CredentialPlugin, pr.URL)
+				if err != nil {
+					return "", err
+				}
+			default:
+				return "", fmt.Errorf("Unsupported Credential Plugin")
+			}
+		}
+	} else {
+		authConfig = types.AuthConfig{
+			Username: pr.User,
+			Password: pr.Password,
+		}
 	}
+
 	encodedJSON, err := json.Marshal(authConfig)
 	if err != nil {
 		return "", err
@@ -738,12 +759,20 @@ func isContainerEnvChanged(containerEnv, imageConfigEnv, dockerfileEnv []string)
 
 func GetKubeletDockerConfig(prsMap map[string]v3.PrivateRegistry) (string, error) {
 	auths := map[string]authConfig{}
-
+	credHelper := make(map[string]string)
 	for url, pr := range prsMap {
-		auth := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", pr.User, pr.Password)))
-		auths[url] = authConfig{Auth: auth}
+		if len(pr.CredentialPlugin) != 0 {
+			if credPluginType, ok := pr.CredentialPlugin["type"]; ok {
+				if credPluginType == "ecr" {
+					credHelper[pr.URL] = "ecr-login"
+				}
+			}
+		} else {
+			auth := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", pr.User, pr.Password)))
+			auths[url] = authConfig{Auth: auth}
+		}
 	}
-	cfg, err := json.Marshal(dockerConfig{auths})
+	cfg, err := json.Marshal(dockerConfig{auths, credHelper})
 	if err != nil {
 		return "", err
 	}
diff --git a/go.mod b/go.mod
index 70cf9359..75655837 100644
--- a/go.mod
+++ b/go.mod
@@ -11,6 +11,7 @@ require (
 	github.com/Masterminds/sprig/v3 v3.2.2
 	github.com/Microsoft/hcsshim v0.8.9 // indirect
 	github.com/apparentlymart/go-cidr v1.0.1
+	github.com/aws/aws-sdk-go v1.38.65
 	github.com/blang/semver v3.5.1+incompatible
 	github.com/containerd/containerd v1.4.1-0.20201117152358-0edc412565dc // indirect
 	github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect
diff --git a/go.sum b/go.sum
index 8ee5be6f..1021c292 100644
--- a/go.sum
+++ b/go.sum
@@ -71,6 +71,8 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV
 github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
 github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
+github.com/aws/aws-sdk-go v1.38.65 h1:umGu5gjIOKxzhi34T0DIA1TWupUDjV2aAW5vK6154Gg=
+github.com/aws/aws-sdk-go v1.38.65/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
 github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -363,6 +365,10 @@ github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA=
 github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
 github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
 github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
 github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
 github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
diff --git a/types/rke_types.go b/types/rke_types.go
index ab80b01a..fc059505 100644
--- a/types/rke_types.go
+++ b/types/rke_types.go
@@ -110,6 +110,8 @@ type PrivateRegistry struct {
 	Password string `yaml:"password" json:"password,omitempty" norman:"type=password"`
 	// Default registry
 	IsDefault bool `yaml:"is_default" json:"isDefault,omitempty"`
+	// CredentialPlugin
+	CredentialPlugin map[string]string `yaml:"credentialPlugin" json:"credentialPlugin,omitempty"`
 }
 
 type RKESystemImages struct {
diff --git a/types/zz_generated_deepcopy.go b/types/zz_generated_deepcopy.go
index 74b77b44..5a923a2a 100644
--- a/types/zz_generated_deepcopy.go
+++ b/types/zz_generated_deepcopy.go
@@ -1354,6 +1354,13 @@ func (in *PortCheck) DeepCopy() *PortCheck {
 // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
 func (in *PrivateRegistry) DeepCopyInto(out *PrivateRegistry) {
 	*out = *in
+	if in.CredentialPlugin != nil {
+		in, out := &in.CredentialPlugin, &out.CredentialPlugin
+		*out = make(map[string]string, len(*in))
+		for key, val := range *in {
+			(*out)[key] = val
+		}
+	}
 	return
 }
 
@@ -1624,7 +1631,9 @@ func (in *RancherKubernetesEngineConfig) DeepCopyInto(out *RancherKubernetesEngi
 	if in.PrivateRegistries != nil {
 		in, out := &in.PrivateRegistries, &out.PrivateRegistries
 		*out = make([]PrivateRegistry, len(*in))
-		copy(*out, *in)
+		for i := range *in {
+			(*in)[i].DeepCopyInto(&(*out)[i])
+		}
 	}
 	in.Ingress.DeepCopyInto(&out.Ingress)
 	in.CloudProvider.DeepCopyInto(&out.CloudProvider)
diff --git a/util/ecr.go b/util/ecr.go
new file mode 100644
index 00000000..e5465726
--- /dev/null
+++ b/util/ecr.go
@@ -0,0 +1,89 @@
+package util
+
+import (
+	"encoding/base64"
+	"fmt"
+	"regexp"
+	"strings"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/ecr"
+	"github.com/docker/docker/api/types"
+)
+
+const proxyEndpointScheme = "https://"
+
+var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?`)
+
+// ECRCredentialPlugin is a wrapper to generate ECR token using the AWS Credentials
+func ECRCredentialPlugin(plugin map[string]string, pr string) (authConfig types.AuthConfig, err error) {
+
+	if strings.HasPrefix(pr, proxyEndpointScheme) {
+		pr = strings.TrimPrefix(pr, proxyEndpointScheme)
+	}
+	matches := ecrPattern.FindStringSubmatch(pr)
+	if len(matches) == 0 {
+		return authConfig, fmt.Errorf("Not a valid ECR registry")
+	} else if len(matches) < 3 {
+		return authConfig, fmt.Errorf(pr + "is not a valid repository URI for Amazon Elastic Container Registry.")
+	}
+
+	config := &aws.Config{
+		Region: aws.String(matches[3]),
+	}
+
+	var sess *session.Session
+	awsAccessKeyID, accessKeyOK := plugin["aws_access_key_id"]
+	awsSecretAccessKey, secretKeyOK := plugin["aws_secret_access_key"]
+
+	// Use predefined keys and override env lookup if keys are present //
+	if accessKeyOK && secretKeyOK {
+		// if session token doesnt exist just pass empty string
+		awsSessionToken := plugin["aws_session_token"]
+		config.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken)
+		sess, err = session.NewSession(config)
+	} else {
+		sess, err = session.NewSessionWithOptions(session.Options{
+			Config:            *config,
+			SharedConfigState: session.SharedConfigEnable,
+		})
+	}
+
+	if err != nil {
+		return authConfig, err
+	}
+
+	ecrClient := ecr.New(sess)
+
+	result, err := ecrClient.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{})
+	if err != nil {
+		return authConfig, err
+	}
+	if len(result.AuthorizationData) == 0 {
+		return authConfig, fmt.Errorf("No authorization data returned")
+	}
+
+	authConfig, err = extractToken(*result.AuthorizationData[0].AuthorizationToken)
+	return authConfig, err
+}
+
+func extractToken(token string) (authConfig types.AuthConfig, err error) {
+	decodedToken, err := base64.StdEncoding.DecodeString(token)
+	if err != nil {
+		return authConfig, fmt.Errorf("Invalid token: %v", err)
+	}
+
+	parts := strings.SplitN(string(decodedToken), ":", 2)
+	if len(parts) < 2 {
+		return authConfig, fmt.Errorf("Invalid token: expected two parts, got %d", len(parts))
+	}
+
+	authConfig = types.AuthConfig{
+		Username: parts[0],
+		Password: parts[1],
+	}
+
+	return authConfig, nil
+}