diff --git a/cmd/kubeadm/app/apis/kubeadm/fuzzer/fuzzer.go b/cmd/kubeadm/app/apis/kubeadm/fuzzer/fuzzer.go index 8ee7b4a7be5..c804505fe00 100644 --- a/cmd/kubeadm/app/apis/kubeadm/fuzzer/fuzzer.go +++ b/cmd/kubeadm/app/apis/kubeadm/fuzzer/fuzzer.go @@ -22,6 +22,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" runtimeserializer "k8s.io/apimachinery/pkg/runtime/serializer" + "k8s.io/utils/ptr" bootstraptokenv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/bootstraptoken/v1" "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm" @@ -62,6 +63,7 @@ func fuzzInitConfiguration(obj *kubeadm.InitConfiguration, c fuzz.Continue) { } obj.SkipPhases = nil obj.NodeRegistration.ImagePullPolicy = corev1.PullIfNotPresent + obj.NodeRegistration.ImagePullSerial = ptr.To(true) obj.Patches = nil obj.DryRun = false kubeadm.SetDefaultTimeouts(&obj.Timeouts) @@ -72,6 +74,7 @@ func fuzzNodeRegistration(obj *kubeadm.NodeRegistrationOptions, c fuzz.Continue) // Pinning values for fields that get defaults if fuzz value is empty string or nil (thus making the round trip test fail) obj.IgnorePreflightErrors = nil + obj.ImagePullSerial = ptr.To(true) } func fuzzClusterConfiguration(obj *kubeadm.ClusterConfiguration, c fuzz.Continue) { @@ -132,6 +135,7 @@ func fuzzJoinConfiguration(obj *kubeadm.JoinConfiguration, c fuzz.Continue) { } obj.SkipPhases = nil obj.NodeRegistration.ImagePullPolicy = corev1.PullIfNotPresent + obj.NodeRegistration.ImagePullSerial = ptr.To(true) obj.Patches = nil obj.DryRun = false kubeadm.SetDefaultTimeouts(&obj.Timeouts) diff --git a/cmd/kubeadm/app/apis/kubeadm/types.go b/cmd/kubeadm/app/apis/kubeadm/types.go index 0e8ce6d64b2..2e2b334c2f9 100644 --- a/cmd/kubeadm/app/apis/kubeadm/types.go +++ b/cmd/kubeadm/app/apis/kubeadm/types.go @@ -249,6 +249,9 @@ type NodeRegistrationOptions struct { // The value of this field must be one of "Always", "IfNotPresent" or "Never". // If this field is unset kubeadm will default it to "IfNotPresent", or pull the required images if not present on the host. ImagePullPolicy v1.PullPolicy `json:"imagePullPolicy,omitempty"` + + // ImagePullSerial specifies if image pulling performed by kubeadm must be done serially or in parallel. + ImagePullSerial *bool } // Networking contains elements describing cluster's networking configuration. diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta3/conversion.go b/cmd/kubeadm/app/apis/kubeadm/v1beta3/conversion.go index 384788cc3c0..6d1cd7d8f19 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta3/conversion.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta3/conversion.go @@ -20,6 +20,7 @@ import ( "sort" "k8s.io/apimachinery/pkg/conversion" + "k8s.io/utils/ptr" "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm" ) @@ -98,6 +99,7 @@ func Convert_kubeadm_LocalEtcd_To_v1beta3_LocalEtcd(in *kubeadm.LocalEtcd, out * // Convert_v1beta3_NodeRegistrationOptions_To_kubeadm_NodeRegistrationOptions converts a public NodeRegistrationOptions to private NodeRegistrationOptions. func Convert_v1beta3_NodeRegistrationOptions_To_kubeadm_NodeRegistrationOptions(in *NodeRegistrationOptions, out *kubeadm.NodeRegistrationOptions, s conversion.Scope) error { out.KubeletExtraArgs = convertToArgs(in.KubeletExtraArgs) + out.ImagePullSerial = ptr.To(true) return autoConvert_v1beta3_NodeRegistrationOptions_To_kubeadm_NodeRegistrationOptions(in, out, s) } diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta3/zz_generated.conversion.go b/cmd/kubeadm/app/apis/kubeadm/v1beta3/zz_generated.conversion.go index 53937bbf05d..39c0bc82264 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta3/zz_generated.conversion.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta3/zz_generated.conversion.go @@ -708,6 +708,7 @@ func autoConvert_kubeadm_NodeRegistrationOptions_To_v1beta3_NodeRegistrationOpti // WARNING: in.KubeletExtraArgs requires manual conversion: inconvertible types ([]k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm.Arg vs map[string]string) out.IgnorePreflightErrors = *(*[]string)(unsafe.Pointer(&in.IgnorePreflightErrors)) out.ImagePullPolicy = corev1.PullPolicy(in.ImagePullPolicy) + // WARNING: in.ImagePullSerial requires manual conversion: does not exist in peer-type return nil } diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta4/defaults.go b/cmd/kubeadm/app/apis/kubeadm/v1beta4/defaults.go index 9348f3acc67..529aa8cf948 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta4/defaults.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta4/defaults.go @@ -22,6 +22,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" bootstraptokenv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/bootstraptoken/v1" "k8s.io/kubernetes/cmd/kubeadm/app/constants" @@ -194,6 +195,9 @@ func SetDefaults_NodeRegistration(obj *NodeRegistrationOptions) { if len(obj.ImagePullPolicy) == 0 { obj.ImagePullPolicy = DefaultImagePullPolicy } + if obj.ImagePullSerial == nil { + obj.ImagePullSerial = ptr.To(true) + } } // SetDefaults_ResetConfiguration assigns default values for the ResetConfiguration object diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta4/doc.go b/cmd/kubeadm/app/apis/kubeadm/v1beta4/doc.go index 4f5d0135aef..257dee63c36 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta4/doc.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta4/doc.go @@ -40,6 +40,8 @@ limitations under the License. // during cluster creation will set the same fields to `false`. // - Add a `Timeouts` structure to `InitConfiguration`, `JoinConfiguration` and `ResetConfiguration“ // that can be used to configure various timeouts. +// - Add the `NodeRegistration.ImagePullSerial` field in 'InitConfiguration` and `JoinConfiguration`, which +// can be used to control if kubeadm pulls images serially or in parallel. // // Migration from old kubeadm config versions // diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta4/types.go b/cmd/kubeadm/app/apis/kubeadm/v1beta4/types.go index 0eec1df807f..f5a74c993e2 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta4/types.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta4/types.go @@ -263,6 +263,11 @@ type NodeRegistrationOptions struct { // If this field is unset kubeadm will default it to "IfNotPresent", or pull the required images if not present on the host. // +optional ImagePullPolicy corev1.PullPolicy `json:"imagePullPolicy,omitempty"` + + // ImagePullSerial specifies if image pulling performed by kubeadm must be done serially or in parallel. + // Default: true + // +optional + ImagePullSerial *bool `json:"imagePullSerial,omitempty"` } // Networking contains elements describing cluster's networking configuration diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.conversion.go b/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.conversion.go index 1419357bcd7..4cff41d0938 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.conversion.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.conversion.go @@ -810,6 +810,7 @@ func autoConvert_v1beta4_NodeRegistrationOptions_To_kubeadm_NodeRegistrationOpti out.KubeletExtraArgs = *(*[]kubeadm.Arg)(unsafe.Pointer(&in.KubeletExtraArgs)) out.IgnorePreflightErrors = *(*[]string)(unsafe.Pointer(&in.IgnorePreflightErrors)) out.ImagePullPolicy = v1.PullPolicy(in.ImagePullPolicy) + out.ImagePullSerial = (*bool)(unsafe.Pointer(in.ImagePullSerial)) return nil } @@ -825,6 +826,7 @@ func autoConvert_kubeadm_NodeRegistrationOptions_To_v1beta4_NodeRegistrationOpti out.KubeletExtraArgs = *(*[]Arg)(unsafe.Pointer(&in.KubeletExtraArgs)) out.IgnorePreflightErrors = *(*[]string)(unsafe.Pointer(&in.IgnorePreflightErrors)) out.ImagePullPolicy = v1.PullPolicy(in.ImagePullPolicy) + out.ImagePullSerial = (*bool)(unsafe.Pointer(in.ImagePullSerial)) return nil } diff --git a/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.deepcopy.go b/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.deepcopy.go index 07ba43306f2..55ec68abca4 100644 --- a/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.deepcopy.go +++ b/cmd/kubeadm/app/apis/kubeadm/v1beta4/zz_generated.deepcopy.go @@ -518,6 +518,11 @@ func (in *NodeRegistrationOptions) DeepCopyInto(out *NodeRegistrationOptions) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.ImagePullSerial != nil { + in, out := &in.ImagePullSerial, &out.ImagePullSerial + *out = new(bool) + **out = **in + } return } diff --git a/cmd/kubeadm/app/apis/kubeadm/validation/validation.go b/cmd/kubeadm/app/apis/kubeadm/validation/validation.go index 9923e5326a7..ac5871d3c2a 100644 --- a/cmd/kubeadm/app/apis/kubeadm/validation/validation.go +++ b/cmd/kubeadm/app/apis/kubeadm/validation/validation.go @@ -29,6 +29,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/pflag" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation" "k8s.io/apimachinery/pkg/util/validation/field" @@ -135,6 +136,7 @@ func ValidateNodeRegistrationOptions(nro *kubeadm.NodeRegistrationOptions, fldPa } allErrs = append(allErrs, ValidateSocketPath(nro.CRISocket, fldPath.Child("criSocket"))...) allErrs = append(allErrs, ValidateExtraArgs(nro.KubeletExtraArgs, fldPath.Child("kubeletExtraArgs"))...) + allErrs = append(allErrs, ValidateImagePullPolicy(nro.ImagePullPolicy, fldPath.Child("imagePullPolicy"))...) // TODO: Maybe validate .Taints as well in the future using something like validateNodeTaints() in pkg/apis/core/validation return allErrs } @@ -739,3 +741,15 @@ func ValidateUnmountFlags(flags []string, fldPath *field.Path) field.ErrorList { return allErrs } + +// ValidateImagePullPolicy validates if the user specified pull policy is correct +func ValidateImagePullPolicy(policy corev1.PullPolicy, fldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + switch policy { + case "", corev1.PullAlways, corev1.PullIfNotPresent, corev1.PullNever: + return allErrs + default: + allErrs = append(allErrs, field.Invalid(fldPath, policy, "invalid pull policy")) + return allErrs + } +} diff --git a/cmd/kubeadm/app/apis/kubeadm/validation/validation_test.go b/cmd/kubeadm/app/apis/kubeadm/validation/validation_test.go index 02b25392217..31fe675a7a8 100644 --- a/cmd/kubeadm/app/apis/kubeadm/validation/validation_test.go +++ b/cmd/kubeadm/app/apis/kubeadm/validation/validation_test.go @@ -24,6 +24,7 @@ import ( "github.com/spf13/pflag" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" @@ -1503,3 +1504,34 @@ func TestValidateUnmountFlags(t *testing.T) { } } } + +func TestPullPolicy(t *testing.T) { + var tests = []struct { + name string + policy string + expectedErrors int + }{ + { + name: "empty policy causes no errors", // gets defaulted + policy: "", + expectedErrors: 0, + }, + { + name: "invalid policy", + policy: "foo", + expectedErrors: 1, + }, + { + name: "valid policy", + policy: "IfNotPresent", + expectedErrors: 0, + }, + } + + for _, tc := range tests { + actual := ValidateImagePullPolicy(corev1.PullPolicy(tc.policy), nil) + if len(actual) != tc.expectedErrors { + t.Errorf("case %q:\n\t expected errors: %v\n\t got: %v\n\t errors: %v", tc.name, tc.expectedErrors, len(actual), actual) + } + } +} diff --git a/cmd/kubeadm/app/apis/kubeadm/zz_generated.deepcopy.go b/cmd/kubeadm/app/apis/kubeadm/zz_generated.deepcopy.go index b5d96201fcc..c9b8749d8af 100644 --- a/cmd/kubeadm/app/apis/kubeadm/zz_generated.deepcopy.go +++ b/cmd/kubeadm/app/apis/kubeadm/zz_generated.deepcopy.go @@ -558,6 +558,11 @@ func (in *NodeRegistrationOptions) DeepCopyInto(out *NodeRegistrationOptions) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.ImagePullSerial != nil { + in, out := &in.ImagePullSerial, &out.ImagePullSerial + *out = new(bool) + **out = **in + } return } diff --git a/cmd/kubeadm/app/cmd/init_test.go b/cmd/kubeadm/app/cmd/init_test.go index 9fb9e8ce3ee..88a85a29833 100644 --- a/cmd/kubeadm/app/cmd/init_test.go +++ b/cmd/kubeadm/app/cmd/init_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/utils/ptr" bootstraptokenv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/bootstraptoken/v1" kubeadmapi "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm" @@ -134,6 +135,7 @@ func TestNewInitData(t *testing.T) { CRISocket: expectedCRISocket, IgnorePreflightErrors: []string{"c", "d"}, ImagePullPolicy: "IfNotPresent", + ImagePullSerial: ptr.To(true), }, LocalAPIEndpoint: kubeadmapi.APIEndpoint{ AdvertiseAddress: "1.2.3.4", diff --git a/cmd/kubeadm/app/cmd/join_test.go b/cmd/kubeadm/app/cmd/join_test.go index 4f556f71521..31eb0a6c350 100644 --- a/cmd/kubeadm/app/cmd/join_test.go +++ b/cmd/kubeadm/app/cmd/join_test.go @@ -31,6 +31,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" + "k8s.io/utils/ptr" kubeadmapi "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm" kubeadmapiv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" @@ -221,6 +222,7 @@ func TestNewJoinData(t *testing.T) { CRISocket: expectedCRISocket, IgnorePreflightErrors: []string{"c", "d"}, ImagePullPolicy: "IfNotPresent", + ImagePullSerial: ptr.To(true), Taints: []v1.Taint{{Key: "node-role.kubernetes.io/control-plane", Effect: "NoSchedule"}}, }, CACertPath: kubeadmapiv1.DefaultCACertPath, diff --git a/cmd/kubeadm/app/preflight/checks.go b/cmd/kubeadm/app/preflight/checks.go index 70265a84cf2..bb4aa2319b0 100644 --- a/cmd/kubeadm/app/preflight/checks.go +++ b/cmd/kubeadm/app/preflight/checks.go @@ -815,6 +815,7 @@ type ImagePullCheck struct { imageList []string sandboxImage string imagePullPolicy v1.PullPolicy + imagePullSerial bool } // Name returns the label for ImagePullCheck @@ -824,22 +825,39 @@ func (ImagePullCheck) Name() string { // Check pulls images required by kubeadm. This is a mutating check func (ipc ImagePullCheck) Check() (warnings, errorList []error) { + // Handle unsupported image pull policy and policy Never. policy := ipc.imagePullPolicy - klog.V(1).Infof("using image pull policy: %s", policy) - for _, image := range ipc.imageList { - if image == ipc.sandboxImage { - criSandboxImage, err := ipc.runtime.SandboxImage() - if err != nil { - klog.V(4).Infof("failed to detect the sandbox image for local container runtime, %v", err) - } else if criSandboxImage != ipc.sandboxImage { - klog.Warningf("detected that the sandbox image %q of the container runtime is inconsistent with that used by kubeadm. It is recommended that using %q as the CRI sandbox image.", - criSandboxImage, ipc.sandboxImage) - } + switch policy { + case v1.PullAlways, v1.PullIfNotPresent: + klog.V(1).Infof("using image pull policy: %s", policy) + case v1.PullNever: + klog.V(1).Infof("skipping the pull of all images due to policy: %s", policy) + return warnings, errorList + default: + errorList = append(errorList, errors.Errorf("unsupported pull policy %q", policy)) + return warnings, errorList + } + + // Handle CRI sandbox image warnings. + criSandboxImage, err := ipc.runtime.SandboxImage() + if err != nil { + klog.V(4).Infof("failed to detect the sandbox image for local container runtime, %v", err) + } else if criSandboxImage != ipc.sandboxImage { + klog.Warningf("detected that the sandbox image %q of the container runtime is inconsistent with that used by kubeadm."+ + "It is recommended to use %q as the CRI sandbox image.", criSandboxImage, ipc.sandboxImage) + } + + // Perform parallel pulls. + if !ipc.imagePullSerial { + if err := ipc.runtime.PullImagesInParallel(ipc.imageList, policy == v1.PullIfNotPresent); err != nil { + errorList = append(errorList, err) } + return warnings, errorList + } + + // Perform serial pulls. + for _, image := range ipc.imageList { switch policy { - case v1.PullNever: - klog.V(1).Infof("skipping pull of image: %s", image) - continue case v1.PullIfNotPresent: ret, err := ipc.runtime.ImageExists(image) if ret && err == nil { @@ -853,14 +871,11 @@ func (ipc ImagePullCheck) Check() (warnings, errorList []error) { case v1.PullAlways: klog.V(1).Infof("pulling: %s", image) if err := ipc.runtime.PullImage(image); err != nil { - errorList = append(errorList, errors.Wrapf(err, "failed to pull image %s", image)) + errorList = append(errorList, errors.WithMessagef(err, "failed to pull image %s", image)) } - default: - // If the policy is unknown return early with an error - errorList = append(errorList, errors.Errorf("unsupported pull policy %q", policy)) - return warnings, errorList } } + return warnings, errorList } @@ -1096,12 +1111,18 @@ func RunPullImagesCheck(execer utilsexec.Interface, cfg *kubeadmapi.InitConfigur return &Error{Msg: err.Error()} } + serialPull := true + if cfg.NodeRegistration.ImagePullSerial != nil { + serialPull = *cfg.NodeRegistration.ImagePullSerial + } + checks := []Checker{ ImagePullCheck{ runtime: containerRuntime, imageList: images.GetControlPlaneImages(&cfg.ClusterConfiguration), sandboxImage: images.GetPauseImage(&cfg.ClusterConfiguration), imagePullPolicy: cfg.NodeRegistration.ImagePullPolicy, + imagePullSerial: serialPull, }, } return RunChecks(checks, os.Stderr, ignorePreflightErrors) diff --git a/cmd/kubeadm/app/preflight/checks_test.go b/cmd/kubeadm/app/preflight/checks_test.go index 28a493ea9d0..8ccf303c08c 100644 --- a/cmd/kubeadm/app/preflight/checks_test.go +++ b/cmd/kubeadm/app/preflight/checks_test.go @@ -866,9 +866,11 @@ func TestImagePullCheck(t *testing.T) { }, CombinedOutputScript: []fakeexec.FakeAction{ // Test case1: pull only img3 + func() ([]byte, []byte, error) { return []byte("pause"), nil, nil }, func() ([]byte, []byte, error) { return nil, nil, nil }, // Test case 2: fail to pull image2 and image3 // If the pull fails, it will be retried 5 times (see PullImageRetry in constants/constants.go) + func() ([]byte, []byte, error) { return []byte("pause"), nil, nil }, func() ([]byte, []byte, error) { return nil, nil, nil }, func() ([]byte, []byte, error) { return []byte("error"), nil, &fakeexec.FakeExitError{Status: 1} }, func() ([]byte, []byte, error) { return []byte("error"), nil, &fakeexec.FakeExitError{Status: 1} }, @@ -903,6 +905,8 @@ func TestImagePullCheck(t *testing.T) { func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) }, + func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) }, + func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) }, }, LookPathFunc: func(cmd string) (string, error) { return "/usr/bin/crictl", nil }, } @@ -914,8 +918,10 @@ func TestImagePullCheck(t *testing.T) { check := ImagePullCheck{ runtime: containerRuntime, + sandboxImage: "pause", imageList: []string{"img1", "img2", "img3"}, imagePullPolicy: corev1.PullIfNotPresent, + imagePullSerial: true, } warnings, errors := check.Check() if len(warnings) != 0 { @@ -936,8 +942,10 @@ func TestImagePullCheck(t *testing.T) { // Test with unknown policy check = ImagePullCheck{ runtime: containerRuntime, + sandboxImage: "pause", imageList: []string{"img1", "img2", "img3"}, imagePullPolicy: "", + imagePullSerial: true, } _, errors = check.Check() if len(errors) != 1 { diff --git a/cmd/kubeadm/app/util/runtime/runtime.go b/cmd/kubeadm/app/util/runtime/runtime.go index 6767d44ae86..1036fa4c60d 100644 --- a/cmd/kubeadm/app/util/runtime/runtime.go +++ b/cmd/kubeadm/app/util/runtime/runtime.go @@ -43,6 +43,7 @@ type ContainerRuntime interface { ListKubeContainers() ([]string, error) RemoveContainers(containers []string) error PullImage(image string) error + PullImagesInParallel(images []string, ifNotPresent bool) error ImageExists(image string) (bool, error) SandboxImage() (string, error) } @@ -139,6 +140,53 @@ func (runtime *CRIRuntime) PullImage(image string) error { return errors.Wrapf(err, "output: %s, error", out) } +// PullImagesInParallel pulls a list of images in parallel +func (runtime *CRIRuntime) PullImagesInParallel(images []string, ifNotPresent bool) error { + errs := pullImagesInParallelImpl(images, ifNotPresent, runtime.ImageExists, runtime.PullImage) + return errorsutil.NewAggregate(errs) +} + +func pullImagesInParallelImpl(images []string, ifNotPresent bool, + imageExistsFunc func(string) (bool, error), pullImageFunc func(string) error) []error { + + var errs []error + errChan := make(chan error, len(images)) + + klog.V(1).Info("pulling all images in parallel") + for _, img := range images { + image := img + go func() { + if ifNotPresent { + exists, err := imageExistsFunc(image) + if err != nil { + errChan <- errors.WithMessagef(err, "failed to check if image %s exists", image) + return + } + if exists { + klog.V(1).Infof("image exists: %s", image) + errChan <- nil + return + } + } + err := pullImageFunc(image) + if err != nil { + err = errors.WithMessagef(err, "failed to pull image %s", image) + } else { + klog.V(1).Infof("done pulling: %s", image) + } + errChan <- err + }() + } + + for i := 0; i < len(images); i++ { + if err := <-errChan; err != nil { + errs = append(errs, err) + } + } + + return errs +} + // ImageExists checks to see if the image exists on the system func (runtime *CRIRuntime) ImageExists(image string) (bool, error) { err := runtime.crictl("inspecti", image).Run() diff --git a/cmd/kubeadm/app/util/runtime/runtime_test.go b/cmd/kubeadm/app/util/runtime/runtime_test.go index fc2fcb48ce3..d1990923414 100644 --- a/cmd/kubeadm/app/util/runtime/runtime_test.go +++ b/cmd/kubeadm/app/util/runtime/runtime_test.go @@ -461,3 +461,78 @@ func TestDetectCRISocketImpl(t *testing.T) { }) } } + +func TestPullImagesInParallelImpl(t *testing.T) { + testError := errors.New("error") + + tests := []struct { + name string + images []string + ifNotPresent bool + imageExistsFunc func(string) (bool, error) + pullImageFunc func(string) error + expectedErrors int + }{ + { + name: "all images exist, no errors", + images: []string{"foo", "bar", "baz"}, + ifNotPresent: true, + imageExistsFunc: func(string) (bool, error) { + return true, nil + }, + pullImageFunc: nil, + expectedErrors: 0, + }, + { + name: "cannot check if one image exists due to error", + images: []string{"foo", "bar", "baz"}, + ifNotPresent: true, + imageExistsFunc: func(image string) (bool, error) { + if image == "baz" { + return false, testError + } + return true, nil + }, + pullImageFunc: nil, + expectedErrors: 1, + }, + { + name: "cannot pull two images", + images: []string{"foo", "bar", "baz"}, + ifNotPresent: true, + imageExistsFunc: func(string) (bool, error) { + return false, nil + }, + pullImageFunc: func(image string) error { + if image == "foo" { + return nil + } + return testError + }, + expectedErrors: 2, + }, + { + name: "pull all images", + images: []string{"foo", "bar", "baz"}, + ifNotPresent: true, + imageExistsFunc: func(string) (bool, error) { + return false, nil + }, + pullImageFunc: func(string) error { + return nil + }, + expectedErrors: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual := pullImagesInParallelImpl(tc.images, tc.ifNotPresent, + tc.imageExistsFunc, tc.pullImageFunc) + if len(actual) != tc.expectedErrors { + t.Fatalf("expected non-nil errors: %v, got: %v, full list of errors: %v", + tc.expectedErrors, len(actual), actual) + } + }) + } +}