diff --git a/cluster/plan.go b/cluster/plan.go index 02ae23c2..8dcd1883 100644 --- a/cluster/plan.go +++ b/cluster/plan.go @@ -3,9 +3,13 @@ package cluster import ( "context" "crypto/md5" + "crypto/sha256" b64 "encoding/base64" + "encoding/hex" "encoding/json" "fmt" + "hash" + "log" "net" "path" "strconv" @@ -216,7 +220,7 @@ func (c *Cluster) BuildKubeAPIProcess(host *hosts.Host, serviceOptions v3.Kubern CommandArgs["authentication-token-webhook-cache-ttl"] = c.Authentication.Webhook.CacheTimeout } if len(c.CloudProvider.Name) > 0 { - Env = append(Env, fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile))) + Env = append(Env, fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile, c.Version))) } if c.EncryptionConfig.EncryptionProviderFile != "" { CommandArgs[EncryptionProviderConfigArgument] = EncryptionProviderFilePath @@ -293,7 +297,7 @@ func (c *Cluster) BuildKubeAPIProcess(host *hosts.Host, serviceOptions v3.Kubern if err != nil { logrus.Warnf("Error while marshalling admission configuration: %v", err) } - Env = append(Env, fmt.Sprintf("%s=%s", AdmissionConfigSumEnv, getStringChecksum(string(bytes)))) + Env = append(Env, fmt.Sprintf("%s=%s", AdmissionConfigSumEnv, getStringChecksum(string(bytes), c.Version))) } if c.Services.KubeAPI.AuditLog != nil && c.Services.KubeAPI.AuditLog.Enabled { Binds = append(Binds, fmt.Sprintf("%s:/var/log/kube-audit", path.Join(host.PrefixPath, "/var/log/kube-audit"))) @@ -301,7 +305,7 @@ func (c *Cluster) BuildKubeAPIProcess(host *hosts.Host, serviceOptions v3.Kubern if err != nil { logrus.Warnf("Error while marshalling auditlog policy: %v", err) } - Env = append(Env, fmt.Sprintf("%s=%s", AuditLogConfigSumEnv, getStringChecksum(string(bytes)))) + Env = append(Env, fmt.Sprintf("%s=%s", AuditLogConfigSumEnv, getStringChecksum(string(bytes), c.Version))) } matchedRange, err := util.SemVerMatchRange(c.Version, util.SemVerK8sVersion122OrHigher) @@ -379,7 +383,7 @@ func (c *Cluster) BuildKubeControllerProcess(host *hosts.Host, serviceOptions v3 if len(c.CloudProvider.Name) > 0 { c.Services.KubeController.ExtraEnv = append( c.Services.KubeController.ExtraEnv, - fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile))) + fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile, c.Version))) } if serviceOptions.KubeController != nil { @@ -639,7 +643,7 @@ func (c *Cluster) BuildKubeletProcess(host *hosts.Host, serviceOptions v3.Kubern if len(c.CloudProvider.Name) > 0 { Env = append(Env, - fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile))) + fmt.Sprintf("%s=%s", CloudConfigSumEnv, getStringChecksum(c.CloudConfigFile, c.Version))) } if len(c.PrivateRegistriesMap) > 0 { kubeletDockerConfig, _ := docker.GetKubeletDockerConfig(c.PrivateRegistriesMap) @@ -1274,9 +1278,25 @@ func (c *Cluster) getDefaultKubernetesServicesOptions(osType string) (v3.Kuberne return v3.KubernetesServicesOptions{}, fmt.Errorf("getDefaultKubernetesServicesOptions: No serviceOptions found for cluster version [%s] or cluster major version [%s]", c.Version, clusterMajorVersion) } -func getStringChecksum(config string) string { - configByteSum := md5.Sum([]byte(config)) - return fmt.Sprintf("%x", configByteSum) +func getStringChecksum(config string, version string) string { + greaterThan1316, err := util.SemVerMatchRange(version, util.SemVerK8sVersion1316OrHigher) + if err != nil { + logrus.Warnf("failed to check if version %q was greater than 1.31.6: %v, falling back to old behavior", version, err) + } + + var hasher hash.Hash + if greaterThan1316 { + hasher = sha256.New() + } else { + hasher = md5.New() + } + + _, err = hasher.Write([]byte(config)) + if err != nil { + log.Fatalf("failed to hash config: %v", err) + } + + return hex.EncodeToString(hasher.Sum(nil)) } func getUniqStringList(l []string) []string { diff --git a/cluster/plan_test.go b/cluster/plan_test.go index 5dcc3bc0..7735f93c 100644 --- a/cluster/plan_test.go +++ b/cluster/plan_test.go @@ -1,6 +1,9 @@ package cluster import ( + "crypto/md5" + "crypto/sha256" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -50,3 +53,57 @@ func Test_getUniqStringList(t *testing.T) { }) } } + +func Test_getStringChecksum(t *testing.T) { + tests := []struct { + name string + config string + version string + expected string + }{ + { + name: "version greater than 1.31.6, use sha256", + config: "test-config", + version: "v1.32.0-rancher0", + expected: fmt.Sprintf("%x", sha256.Sum256([]byte("test-config"))), + }, + { + name: "version exactly 1.31.6, use sha256", + config: "test-config", + version: "v1.31.6-rancher0", + expected: fmt.Sprintf("%x", sha256.Sum256([]byte("test-config"))), + }, + { + name: "version exactly 1.31.0, use md5", + config: "test-config", + version: "v1.31.0-rancher0", + expected: fmt.Sprintf("%x", md5.Sum([]byte("test-config"))), + }, + { + name: "version less than 1.31, use md5", + config: "test-config", + version: "v1.30.0-rancher0", + expected: fmt.Sprintf("%x", md5.Sum([]byte("test-config"))), + }, + { + name: "empty config", + config: "", + version: "v1.32.0-rancher0", + expected: fmt.Sprintf("%x", sha256.Sum256([]byte(""))), + }, + { + name: "empty version", + config: "test-config", + version: "", + expected: fmt.Sprintf("%x", md5.Sum([]byte("test-config"))), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := getStringChecksum(tt.config, tt.version) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/util/util.go b/util/util.go index 946f0f66..f3caeca2 100644 --- a/util/util.go +++ b/util/util.go @@ -17,8 +17,9 @@ import ( ) const ( - WorkerThreads = 50 - SemVerK8sVersion122OrHigher = ">=1.22.0-rancher0" + WorkerThreads = 50 + SemVerK8sVersion122OrHigher = ">=1.22.0-rancher0" + SemVerK8sVersion1316OrHigher = ">=1.31.6-rancher0" ) var ProxyEnvVars = [3]string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}