From 3d9bb9a128123e4c6956dbf09f8c5e1ca34df8b8 Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Thu, 31 Jul 2025 19:36:10 +0300 Subject: [PATCH] add support for specifying additional certificates (#4153) Signed-off-by: Avi Deitcher --- src/cmd/linuxkit/cmd.go | 14 ++++ src/cmd/linuxkit/pkg.go | 8 ++ src/cmd/linuxkit/registry/remote.go | 121 +++++++++++++++++++++++++--- 3 files changed, 130 insertions(+), 13 deletions(-) diff --git a/src/cmd/linuxkit/cmd.go b/src/cmd/linuxkit/cmd.go index 85747a1ab..fdff24387 100644 --- a/src/cmd/linuxkit/cmd.go +++ b/src/cmd/linuxkit/cmd.go @@ -49,6 +49,7 @@ func newCmd() *cobra.Command { flagVerbose int flagVerboseName = "verbose" mirrorsRaw []string + certFiles []string ) cmd := &cobra.Command{ Use: "linuxkit", @@ -87,6 +88,18 @@ func newCmd() *cobra.Command { } } + for _, f := range certFiles { + if f == "" { + continue + } + cert, err := os.ReadFile(f) + if err != nil { + return fmt.Errorf("failed to read certificate file %q: %w", f, err) + } + // Add the certificate file to the registry + registry.AddCert(cert) + } + // Set up logging return util.SetupLogging(flagQuiet, flagVerbose, cmd.Flag(flagVerboseName).Changed) }, @@ -103,6 +116,7 @@ func newCmd() *cobra.Command { cmd.PersistentFlags().StringVar(&cacheDir, "cache", defaultLinuxkitCache(), fmt.Sprintf("Directory for caching and finding cached image, overrides env var %s", envVarCacheDir)) cmd.PersistentFlags().StringArrayVar(&mirrorsRaw, "mirror", nil, "Mirror to use for pulling images, format is =, e.g. docker.io=http://mymirror.io, or just http://mymirror.io for all not otherwise specified; must include protocol. Can be provided multiple times.") + cmd.PersistentFlags().StringArrayVar(&certFiles, "cert-file", nil, "Path to certificate files to use for pulling images, can be provided multiple times. Will augment system-provided certs.") cmd.PersistentFlags().BoolVarP(&flagQuiet, "quiet", "q", false, "Quiet execution") cmd.PersistentFlags().IntVarP(&flagVerbose, flagVerboseName, "v", 1, "Verbosity of logging: 0 = quiet, 1 = info, 2 = debug, 3 = trace. Default is info. Setting it explicitly will create structured logging lines.") diff --git a/src/cmd/linuxkit/pkg.go b/src/cmd/linuxkit/pkg.go index 462e8eafb..e79a62cf5 100644 --- a/src/cmd/linuxkit/pkg.go +++ b/src/cmd/linuxkit/pkg.go @@ -33,6 +33,14 @@ func pkgCmd() *cobra.Command { Short: "package building and pushing", Long: `Package building and pushing.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + if parent := cmd.Parent(); parent != nil { + if parent.PersistentPreRunE != nil { + if err := parent.PersistentPreRunE(parent, args); err != nil { + return err + } + } + } + pkglibConfig = pkglib.PkglibConfig{ BuildYML: buildYML, Hash: hash, diff --git a/src/cmd/linuxkit/registry/remote.go b/src/cmd/linuxkit/registry/remote.go index 3dda7d238..b5a08bbda 100644 --- a/src/cmd/linuxkit/registry/remote.go +++ b/src/cmd/linuxkit/registry/remote.go @@ -1,7 +1,10 @@ package registry import ( + "crypto/tls" + "crypto/x509" "fmt" + "net/http" "strings" "github.com/google/go-containerregistry/pkg/name" @@ -12,6 +15,9 @@ import ( // proxy is a map of registry names to proxy URLs. var proxy = make(map[string]string) +// certs is a slice of certificates to be used for secure connections. +var certs = make([][]byte, 0) + func SetProxy(registry, url string) { if url == "" { delete(proxy, registry) @@ -20,6 +26,10 @@ func SetProxy(registry, url string) { } } +func AddCert(cert []byte) { + certs = append(certs, cert) +} + // Remote implements the functions of // github.com/google/go-containerregistry/pkg/v1/remote, while possibly pre-configured for // items like proxies, mirrors, authentication, or other settings. @@ -40,8 +50,12 @@ func (r *Remote) Get(ref name.Reference, options ...remote.Option) (*remote.Desc if err != nil { return nil, fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Get(ref, options...) + return remote.Get(ref, opts...) } func (r *Remote) Head(ref name.Reference, options ...remote.Option) (*v1.Descriptor, error) { @@ -50,12 +64,20 @@ func (r *Remote) Head(ref name.Reference, options ...remote.Option) (*v1.Descrip if err != nil { return nil, fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Head(ref, options...) + return remote.Head(ref, opts...) } func (r *Remote) Tag(ref name.Tag, t remote.Taggable, options ...remote.Option) error { - return remote.Tag(ref, t, options...) + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } + return remote.Tag(ref, t, opts...) } func (r *Remote) Push(ref name.Reference, t remote.Taggable, options ...remote.Option) error { @@ -64,8 +86,12 @@ func (r *Remote) Push(ref name.Reference, t remote.Taggable, options ...remote.O if err != nil { return fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Push(ref, t, options...) + return remote.Push(ref, t, opts...) } func (r *Remote) Put(ref name.Reference, t remote.Taggable, options ...remote.Option) error { @@ -74,8 +100,12 @@ func (r *Remote) Put(ref name.Reference, t remote.Taggable, options ...remote.Op if err != nil { return fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Put(ref, t, options...) + return remote.Put(ref, t, opts...) } func (r *Remote) Write(ref name.Reference, img v1.Image, options ...remote.Option) error { @@ -84,8 +114,12 @@ func (r *Remote) Write(ref name.Reference, img v1.Image, options ...remote.Optio if err != nil { return fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Write(ref, img, options...) + return remote.Write(ref, img, opts...) } func (r *Remote) WriteIndex(ref name.Reference, ii v1.ImageIndex, options ...remote.Option) error { @@ -94,8 +128,12 @@ func (r *Remote) WriteIndex(ref name.Reference, ii v1.ImageIndex, options ...rem if err != nil { return fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.WriteIndex(ref, ii, options...) + return remote.WriteIndex(ref, ii, opts...) } func (r *Remote) WriteLayer(repo name.Repository, layer v1.Layer, options ...remote.Option) error { @@ -104,8 +142,12 @@ func (r *Remote) WriteLayer(repo name.Repository, layer v1.Layer, options ...rem if err != nil { return fmt.Errorf("rewriting repository %q: %w", repo.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", repo.Name(), err) + } - return remote.WriteLayer(repo, layer, options...) + return remote.WriteLayer(repo, layer, opts...) } func (r *Remote) List(repo name.Repository, options ...remote.Option) ([]string, error) { @@ -114,7 +156,12 @@ func (r *Remote) List(repo name.Repository, options ...remote.Option) ([]string, if err != nil { return nil, fmt.Errorf("rewriting repository %q: %w", repo.Name(), err) } - return remote.List(repo, options...) + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", repo.Name(), err) + } + + return remote.List(repo, opts...) } func (r *Remote) Layer(ref name.Digest, options ...remote.Option) (v1.Layer, error) { @@ -123,7 +170,11 @@ func (r *Remote) Layer(ref name.Digest, options ...remote.Option) (v1.Layer, err if err != nil { return nil, fmt.Errorf("rewriting digest %q: %w", ref.Name(), err) } - return remote.Layer(ref, options...) + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } + return remote.Layer(ref, opts...) } func (r *Remote) Index(ref name.Reference, options ...remote.Option) (v1.ImageIndex, error) { @@ -132,8 +183,12 @@ func (r *Remote) Index(ref name.Reference, options ...remote.Option) (v1.ImageIn if err != nil { return nil, fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Index(ref, options...) + return remote.Index(ref, opts...) } func (r *Remote) Image(ref name.Reference, options ...remote.Option) (v1.Image, error) { @@ -142,8 +197,12 @@ func (r *Remote) Image(ref name.Reference, options ...remote.Option) (v1.Image, if err != nil { return nil, fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return nil, fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Image(ref, options...) + return remote.Image(ref, opts...) } func (r *Remote) Delete(ref name.Reference, options ...remote.Option) error { @@ -152,8 +211,44 @@ func (r *Remote) Delete(ref name.Reference, options ...remote.Option) error { if err != nil { return fmt.Errorf("rewriting reference %q: %w", ref.Name(), err) } + opts, err := r.rewriteTLSTransport(options) + if err != nil { + return fmt.Errorf("rewriting TLS transport for %q: %w", ref.Name(), err) + } - return remote.Delete(ref, options...) + return remote.Delete(ref, opts...) +} + +func (r *Remote) rewriteTLSTransport(options []remote.Option) ([]remote.Option, error) { + // If there are no certs, return the options as is + if len(certs) == 0 { + return options, nil + } + + caCertPool, err := x509.SystemCertPool() + if err != nil || caCertPool == nil { + caCertPool = x509.NewCertPool() + } + for _, caCert := range certs { + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to append CA certificate") + } + } + + baseTransport := remote.DefaultTransport.(*http.Transport).Clone() + if baseTransport.TLSClientConfig == nil { + baseTransport.TLSClientConfig = &tls.Config{} + } + baseTransport.TLSClientConfig.RootCAs = caCertPool + + // Add the certificates to the options + newOptions := make([]remote.Option, 0, len(options)+1) + newOptions = append(newOptions, options...) + newOptions = append(newOptions, remote.WithTransport( + baseTransport, + )) + + return newOptions, nil } func (r *Remote) rewriteReference(ref name.Reference) (name.Reference, error) {