diff --git a/src/runtime/virtcontainers/factory/factory.go b/src/runtime/virtcontainers/factory/factory.go index 7461a5158e..70aa05d9bc 100644 --- a/src/runtime/virtcontainers/factory/factory.go +++ b/src/runtime/virtcontainers/factory/factory.go @@ -7,17 +7,7 @@ package factory import ( "context" - "fmt" - - "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" - pb "github.com/kata-containers/kata-containers/src/runtime/protocols/cache" vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/base" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/cache" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/direct" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/grpccache" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/template" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils" "github.com/sirupsen/logrus" ) @@ -43,56 +33,6 @@ type Config struct { VMCache bool } -type factory struct { - base base.FactoryBase -} - -// NewFactory returns a working factory. 
-func NewFactory(ctx context.Context, config Config, fetchOnly bool) (vc.Factory, error) { - span, _ := katatrace.Trace(ctx, nil, "NewFactory", factoryTracingTags) - defer span.End() - - err := config.VMConfig.Valid() - if err != nil { - return nil, err - } - - if fetchOnly && config.Cache > 0 { - return nil, fmt.Errorf("cache factory does not support fetch") - } - - var b base.FactoryBase - if config.VMCache && config.Cache == 0 { - // For VMCache client - b, err = grpccache.New(ctx, config.VMCacheEndpoint) - if err != nil { - return nil, err - } - } else { - if config.Template { - if fetchOnly { - b, err = template.Fetch(config.VMConfig, config.TemplatePath) - if err != nil { - return nil, err - } - } else { - b, err = template.New(ctx, config.VMConfig, config.TemplatePath) - if err != nil { - return nil, err - } - } - } else { - b = direct.New(ctx, config.VMConfig) - } - - if config.Cache > 0 { - b = cache.New(ctx, config.Cache, b) - } - } - - return &factory{b}, nil -} - // SetLogger sets the logger for the factory. func SetLogger(ctx context.Context, logger logrus.FieldLogger) { fields := logrus.Fields{ @@ -105,135 +45,3 @@ func SetLogger(ctx context.Context, logger logrus.FieldLogger) { func (f *factory) log() *logrus.Entry { return factoryLogger.WithField("subsystem", "factory") } - -func resetHypervisorConfig(config *vc.VMConfig) { - config.HypervisorConfig.NumVCPUs = 0 - config.HypervisorConfig.MemorySize = 0 - config.HypervisorConfig.BootToBeTemplate = false - config.HypervisorConfig.BootFromTemplate = false - config.HypervisorConfig.MemoryPath = "" - config.HypervisorConfig.DevicesStatePath = "" - config.HypervisorConfig.SharedPath = "" - config.HypervisorConfig.VMStorePath = "" - config.HypervisorConfig.RunStorePath = "" -} - -// It's important that baseConfig and newConfig are passed by value! 
-func checkVMConfig(baseConfig, newConfig vc.VMConfig) error { - if baseConfig.HypervisorType != newConfig.HypervisorType { - return fmt.Errorf("hypervisor type does not match: %s vs. %s", baseConfig.HypervisorType, newConfig.HypervisorType) - } - - // check hypervisor config details - resetHypervisorConfig(&baseConfig) - resetHypervisorConfig(&newConfig) - - if !utils.DeepCompare(baseConfig, newConfig) { - return fmt.Errorf("hypervisor config does not match, base: %+v. new: %+v", baseConfig, newConfig) - } - - return nil -} - -func (f *factory) checkConfig(config vc.VMConfig) error { - baseConfig := f.base.Config() - - return checkVMConfig(baseConfig, config) -} - -// GetVM returns a working blank VM created by the factory. -func (f *factory) GetVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { - span, ctx := katatrace.Trace(ctx, f.log(), "GetVM", factoryTracingTags) - defer span.End() - - hypervisorConfig := config.HypervisorConfig - if err := config.Valid(); err != nil { - f.log().WithError(err).Error("invalid hypervisor config") - return nil, err - } - - err := f.checkConfig(config) - if err != nil { - f.log().WithError(err).Info("fallback to direct factory vm") - return direct.New(ctx, config).GetBaseVM(ctx, config) - } - - f.log().Info("get base VM") - vm, err := f.base.GetBaseVM(ctx, config) - if err != nil { - f.log().WithError(err).Error("failed to get base VM") - return nil, err - } - - // cleanup upon error - defer func() { - if err != nil { - f.log().WithError(err).Error("clean up vm") - vm.Stop(ctx) - } - }() - - err = vm.Resume(ctx) - if err != nil { - return nil, err - } - - // reseed RNG so that shared memory VMs do not generate same random numbers. - err = vm.ReseedRNG(ctx) - if err != nil { - return nil, err - } - - // sync guest time since we might have paused it for a long time. 
- err = vm.SyncTime(ctx) - if err != nil { - return nil, err - } - - online := false - baseConfig := f.base.Config().HypervisorConfig - if baseConfig.NumVCPUs < hypervisorConfig.NumVCPUs { - err = vm.AddCPUs(ctx, hypervisorConfig.NumVCPUs-baseConfig.NumVCPUs) - if err != nil { - return nil, err - } - online = true - } - - if baseConfig.MemorySize < hypervisorConfig.MemorySize { - err = vm.AddMemory(ctx, hypervisorConfig.MemorySize-baseConfig.MemorySize) - if err != nil { - return nil, err - } - online = true - } - - if online { - err = vm.OnlineCPUMemory(ctx) - if err != nil { - return nil, err - } - } - - return vm, nil -} - -// Config returns base factory config. -func (f *factory) Config() vc.VMConfig { - return f.base.Config() -} - -// GetVMStatus returns the status of the paused VM created by the base factory. -func (f *factory) GetVMStatus() []*pb.GrpcVMStatus { - return f.base.GetVMStatus() -} - -// GetBaseVM returns a paused VM created by the base factory. -func (f *factory) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { - return f.base.GetBaseVM(ctx, config) -} - -// CloseFactory closes the factory. -func (f *factory) CloseFactory(ctx context.Context) { - f.base.CloseFactory(ctx) -} diff --git a/src/runtime/virtcontainers/factory/factory_darwin.go b/src/runtime/virtcontainers/factory/factory_darwin.go new file mode 100644 index 0000000000..bdebca9e4e --- /dev/null +++ b/src/runtime/virtcontainers/factory/factory_darwin.go @@ -0,0 +1,43 @@ +// Copyright (c) 2022 Apple Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +// + +package factory + +import ( + "context" + + pb "github.com/kata-containers/kata-containers/src/runtime/protocols/cache" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" + "github.com/pkg/errors" +) + +var unsupportedFactory error = errors.New("VM factory is unsupported on Darwin") + +type factory struct { +} + +func NewFactory(ctx context.Context, config Config, fetchOnly bool) (vc.Factory, error) { + return &factory{}, unsupportedFactory +} + +func (f *factory) Config() vc.VMConfig { + return vc.VMConfig{} +} + +func (f *factory) GetVMStatus() []*pb.GrpcVMStatus { + return nil +} + +func (f *factory) GetVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { + return nil, unsupportedFactory +} + +func (f *factory) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { + return nil, unsupportedFactory +} + +func (f *factory) CloseFactory(ctx context.Context) { + return +} diff --git a/src/runtime/virtcontainers/factory/factory_linux.go b/src/runtime/virtcontainers/factory/factory_linux.go new file mode 100644 index 0000000000..86a384d121 --- /dev/null +++ b/src/runtime/virtcontainers/factory/factory_linux.go @@ -0,0 +1,203 @@ +// Copyright (c) 2018 HyperHQ Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +// + +package factory + +import ( + "context" + "fmt" + + "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" + pb "github.com/kata-containers/kata-containers/src/runtime/protocols/cache" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/base" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/cache" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/direct" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/grpccache" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/template" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils" +) + +type factory struct { + base base.FactoryBase +} + +// NewFactory returns a working factory. +func NewFactory(ctx context.Context, config Config, fetchOnly bool) (vc.Factory, error) { + span, _ := katatrace.Trace(ctx, nil, "NewFactory", factoryTracingTags) + defer span.End() + + err := config.VMConfig.Valid() + if err != nil { + return nil, err + } + + if fetchOnly && config.Cache > 0 { + return nil, fmt.Errorf("cache factory does not support fetch") + } + + var b base.FactoryBase + if config.VMCache && config.Cache == 0 { + // For VMCache client + b, err = grpccache.New(ctx, config.VMCacheEndpoint) + if err != nil { + return nil, err + } + } else { + if config.Template { + if fetchOnly { + b, err = template.Fetch(config.VMConfig, config.TemplatePath) + if err != nil { + return nil, err + } + } else { + b, err = template.New(ctx, config.VMConfig, config.TemplatePath) + if err != nil { + return nil, err + } + } + } else { + b = direct.New(ctx, config.VMConfig) + } + + if config.Cache > 0 { + b = cache.New(ctx, config.Cache, b) + } + } + + return &factory{b}, nil +} + +func resetHypervisorConfig(config 
*vc.VMConfig) { + config.HypervisorConfig.NumVCPUs = 0 + config.HypervisorConfig.MemorySize = 0 + config.HypervisorConfig.BootToBeTemplate = false + config.HypervisorConfig.BootFromTemplate = false + config.HypervisorConfig.MemoryPath = "" + config.HypervisorConfig.DevicesStatePath = "" + config.HypervisorConfig.SharedPath = "" + config.HypervisorConfig.VMStorePath = "" + config.HypervisorConfig.RunStorePath = "" +} + +// It's important that baseConfig and newConfig are passed by value! +func checkVMConfig(baseConfig, newConfig vc.VMConfig) error { + if baseConfig.HypervisorType != newConfig.HypervisorType { + return fmt.Errorf("hypervisor type does not match: %s vs. %s", baseConfig.HypervisorType, newConfig.HypervisorType) + } + + // check hypervisor config details + resetHypervisorConfig(&baseConfig) + resetHypervisorConfig(&newConfig) + + if !utils.DeepCompare(baseConfig, newConfig) { + return fmt.Errorf("hypervisor config does not match, base: %+v. new: %+v", baseConfig, newConfig) + } + + return nil +} + +func (f *factory) checkConfig(config vc.VMConfig) error { + baseConfig := f.base.Config() + + return checkVMConfig(baseConfig, config) +} + +// GetVM returns a working blank VM created by the factory. 
+func (f *factory) GetVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { + span, ctx := katatrace.Trace(ctx, f.log(), "GetVM", factoryTracingTags) + defer span.End() + + hypervisorConfig := config.HypervisorConfig + if err := config.Valid(); err != nil { + f.log().WithError(err).Error("invalid hypervisor config") + return nil, err + } + + err := f.checkConfig(config) + if err != nil { + f.log().WithError(err).Info("fallback to direct factory vm") + return direct.New(ctx, config).GetBaseVM(ctx, config) + } + + f.log().Info("get base VM") + vm, err := f.base.GetBaseVM(ctx, config) + if err != nil { + f.log().WithError(err).Error("failed to get base VM") + return nil, err + } + + // cleanup upon error + defer func() { + if err != nil { + f.log().WithError(err).Error("clean up vm") + vm.Stop(ctx) + } + }() + + err = vm.Resume(ctx) + if err != nil { + return nil, err + } + + // reseed RNG so that shared memory VMs do not generate same random numbers. + err = vm.ReseedRNG(ctx) + if err != nil { + return nil, err + } + + // sync guest time since we might have paused it for a long time. + err = vm.SyncTime(ctx) + if err != nil { + return nil, err + } + + online := false + baseConfig := f.base.Config().HypervisorConfig + if baseConfig.NumVCPUs < hypervisorConfig.NumVCPUs { + err = vm.AddCPUs(ctx, hypervisorConfig.NumVCPUs-baseConfig.NumVCPUs) + if err != nil { + return nil, err + } + online = true + } + + if baseConfig.MemorySize < hypervisorConfig.MemorySize { + err = vm.AddMemory(ctx, hypervisorConfig.MemorySize-baseConfig.MemorySize) + if err != nil { + return nil, err + } + online = true + } + + if online { + err = vm.OnlineCPUMemory(ctx) + if err != nil { + return nil, err + } + } + + return vm, nil +} + +// Config returns base factory config. +func (f *factory) Config() vc.VMConfig { + return f.base.Config() +} + +// GetVMStatus returns the status of the paused VM created by the base factory. 
+func (f *factory) GetVMStatus() []*pb.GrpcVMStatus { + return f.base.GetVMStatus() +} + +// GetBaseVM returns a paused VM created by the base factory. +func (f *factory) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { + return f.base.GetBaseVM(ctx, config) +} + +// CloseFactory closes the factory. +func (f *factory) CloseFactory(ctx context.Context) { + f.base.CloseFactory(ctx) +} diff --git a/src/runtime/virtcontainers/factory/template/template.go b/src/runtime/virtcontainers/factory/template/template.go index 0b3e96dae8..afb45c34e0 100644 --- a/src/runtime/virtcontainers/factory/template/template.go +++ b/src/runtime/virtcontainers/factory/template/template.go @@ -1,3 +1,4 @@ +// // Copyright (c) 2018 HyperHQ Inc. // // SPDX-License-Identifier: Apache-2.0 @@ -8,178 +9,12 @@ package template import ( "context" - "fmt" - "os" - "syscall" - "time" - pb "github.com/kata-containers/kata-containers/src/runtime/protocols/cache" - vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/base" "github.com/sirupsen/logrus" ) -type template struct { - statePath string - config vc.VMConfig -} - -var templateWaitForAgent = 2 * time.Second var templateLog = logrus.WithField("source", "virtcontainers/factory/template") -// Fetch finds and returns a pre-built template factory. -// TODO: save template metadata and fetch from storage. -func Fetch(config vc.VMConfig, templatePath string) (base.FactoryBase, error) { - t := &template{templatePath, config} - - err := t.checkTemplateVM() - if err != nil { - return nil, err - } - - return t, nil -} - -// New creates a new VM template factory. 
-func New(ctx context.Context, config vc.VMConfig, templatePath string) (base.FactoryBase, error) { - t := &template{templatePath, config} - - err := t.checkTemplateVM() - if err == nil { - return nil, fmt.Errorf("There is already a VM template in %s", templatePath) - } - - err = t.prepareTemplateFiles() - if err != nil { - return nil, err - } - defer func() { - if err != nil { - t.close() - } - }() - - err = t.createTemplateVM(ctx) - if err != nil { - return nil, err - } - - return t, nil -} - -// Config returns template factory's configuration. -func (t *template) Config() vc.VMConfig { - return t.config -} - -// GetBaseVM creates a new paused VM from the template VM. -func (t *template) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { - return t.createFromTemplateVM(ctx, config) -} - -// CloseFactory cleans up the template VM. -func (t *template) CloseFactory(ctx context.Context) { - t.close() -} - -// GetVMStatus is not supported -func (t *template) GetVMStatus() []*pb.GrpcVMStatus { - panic("ERROR: package template does not support GetVMStatus") -} - -func (t *template) close() { - if err := syscall.Unmount(t.statePath, syscall.MNT_DETACH); err != nil { - t.Logger().WithError(err).Errorf("failed to unmount %s", t.statePath) - } - - if err := os.RemoveAll(t.statePath); err != nil { - t.Logger().WithError(err).Errorf("failed to remove %s", t.statePath) - } -} - -func (t *template) prepareTemplateFiles() error { - // create and mount tmpfs for the shared memory file - err := os.MkdirAll(t.statePath, 0700) - if err != nil { - return err - } - flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV) - opts := fmt.Sprintf("size=%dM", t.config.HypervisorConfig.MemorySize+templateDeviceStateSize) - if err = syscall.Mount("tmpfs", t.statePath, "tmpfs", flags, opts); err != nil { - t.close() - return err - } - f, err := os.Create(t.statePath + "/memory") - if err != nil { - t.close() - return err - } - f.Close() - - return nil -} - -func (t *template) 
createTemplateVM(ctx context.Context) error { - // create the template vm - config := t.config - config.HypervisorConfig.BootToBeTemplate = true - config.HypervisorConfig.BootFromTemplate = false - config.HypervisorConfig.MemoryPath = t.statePath + "/memory" - config.HypervisorConfig.DevicesStatePath = t.statePath + "/state" - - vm, err := vc.NewVM(ctx, config) - if err != nil { - return err - } - defer vm.Stop(ctx) - - if err = vm.Disconnect(ctx); err != nil { - return err - } - - // Sleep a bit to let the agent grpc server clean up - // When we close connection to the agent, it needs sometime to cleanup - // and restart listening on the communication( serial or vsock) port. - // That time can be saved if we sleep a bit to wait for the agent to - // come around and start listening again. The sleep is only done when - // creating new vm templates and saves time for every new vm that are - // created from template, so it worth the invest. - time.Sleep(templateWaitForAgent) - - if err = vm.Pause(ctx); err != nil { - return err - } - - if err = vm.Save(); err != nil { - return err - } - - return nil -} - -func (t *template) createFromTemplateVM(ctx context.Context, c vc.VMConfig) (*vc.VM, error) { - config := t.config - config.HypervisorConfig.BootToBeTemplate = false - config.HypervisorConfig.BootFromTemplate = true - config.HypervisorConfig.MemoryPath = t.statePath + "/memory" - config.HypervisorConfig.DevicesStatePath = t.statePath + "/state" - config.HypervisorConfig.SharedPath = c.HypervisorConfig.SharedPath - config.HypervisorConfig.VMStorePath = c.HypervisorConfig.VMStorePath - config.HypervisorConfig.RunStorePath = c.HypervisorConfig.RunStorePath - - return vc.NewVM(ctx, config) -} - -func (t *template) checkTemplateVM() error { - _, err := os.Stat(t.statePath + "/memory") - if err != nil { - return err - } - - _, err = os.Stat(t.statePath + "/state") - return err -} - // Logger returns a logrus logger appropriate for logging template messages func (t 
*template) Logger() *logrus.Entry { return templateLog.WithFields(logrus.Fields{ diff --git a/src/runtime/virtcontainers/factory/template/template_darwin.go b/src/runtime/virtcontainers/factory/template/template_darwin.go new file mode 100644 index 0000000000..bd92898bf3 --- /dev/null +++ b/src/runtime/virtcontainers/factory/template/template_darwin.go @@ -0,0 +1,10 @@ +// +// Copyright (c) 2022 Apple, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// template implements base vm factory with vm templating. + +package template + +type template struct{} diff --git a/src/runtime/virtcontainers/factory/template/template_linux.go b/src/runtime/virtcontainers/factory/template/template_linux.go new file mode 100644 index 0000000000..d48ce5c50b --- /dev/null +++ b/src/runtime/virtcontainers/factory/template/template_linux.go @@ -0,0 +1,180 @@ +// +// Copyright (c) 2018 HyperHQ Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// template implements base vm factory with vm templating. + +package template + +import ( + "context" + "fmt" + "os" + "syscall" + "time" + + pb "github.com/kata-containers/kata-containers/src/runtime/protocols/cache" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" + "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/factory/base" +) + +type template struct { + statePath string + config vc.VMConfig +} + +var templateWaitForAgent = 2 * time.Second + +// Fetch finds and returns a pre-built template factory. +// TODO: save template metadata and fetch from storage. +func Fetch(config vc.VMConfig, templatePath string) (base.FactoryBase, error) { + t := &template{templatePath, config} + + err := t.checkTemplateVM() + if err != nil { + return nil, err + } + + return t, nil +} + +// New creates a new VM template factory. 
+func New(ctx context.Context, config vc.VMConfig, templatePath string) (base.FactoryBase, error) { + t := &template{templatePath, config} + + err := t.checkTemplateVM() + if err == nil { + return nil, fmt.Errorf("There is already a VM template in %s", templatePath) + } + + err = t.prepareTemplateFiles() + if err != nil { + return nil, err + } + defer func() { + if err != nil { + t.close() + } + }() + + err = t.createTemplateVM(ctx) + if err != nil { + return nil, err + } + + return t, nil +} + +// Config returns template factory's configuration. +func (t *template) Config() vc.VMConfig { + return t.config +} + +// GetBaseVM creates a new paused VM from the template VM. +func (t *template) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, error) { + return t.createFromTemplateVM(ctx, config) +} + +// CloseFactory cleans up the template VM. +func (t *template) CloseFactory(ctx context.Context) { + t.close() +} + +// GetVMStatus is not supported +func (t *template) GetVMStatus() []*pb.GrpcVMStatus { + panic("ERROR: package template does not support GetVMStatus") +} + +func (t *template) close() { + if err := syscall.Unmount(t.statePath, syscall.MNT_DETACH); err != nil { + t.Logger().WithError(err).Errorf("failed to unmount %s", t.statePath) + } + + if err := os.RemoveAll(t.statePath); err != nil { + t.Logger().WithError(err).Errorf("failed to remove %s", t.statePath) + } +} + +func (t *template) prepareTemplateFiles() error { + // create and mount tmpfs for the shared memory file + err := os.MkdirAll(t.statePath, 0700) + if err != nil { + return err + } + flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV) + opts := fmt.Sprintf("size=%dM", t.config.HypervisorConfig.MemorySize+templateDeviceStateSize) + if err = syscall.Mount("tmpfs", t.statePath, "tmpfs", flags, opts); err != nil { + t.close() + return err + } + f, err := os.Create(t.statePath + "/memory") + if err != nil { + t.close() + return err + } + f.Close() + + return nil +} + +func (t *template) 
createTemplateVM(ctx context.Context) error { + // create the template vm + config := t.config + config.HypervisorConfig.BootToBeTemplate = true + config.HypervisorConfig.BootFromTemplate = false + config.HypervisorConfig.MemoryPath = t.statePath + "/memory" + config.HypervisorConfig.DevicesStatePath = t.statePath + "/state" + + vm, err := vc.NewVM(ctx, config) + if err != nil { + return err + } + defer vm.Stop(ctx) + + if err = vm.Disconnect(ctx); err != nil { + return err + } + + // Sleep a bit to let the agent grpc server clean up. + // When we close the connection to the agent, it needs some time to clean up + // and restart listening on the communication (serial or vsock) port. + // That time can be saved if we sleep a bit to wait for the agent to + // come around and start listening again. The sleep is only done when + // creating new vm templates and saves time for every new vm that is + // created from a template, so it is worth the investment. + time.Sleep(templateWaitForAgent) + + if err = vm.Pause(ctx); err != nil { + return err + } + + if err = vm.Save(); err != nil { + return err + } + + return nil +} + +func (t *template) createFromTemplateVM(ctx context.Context, c vc.VMConfig) (*vc.VM, error) { + config := t.config + config.HypervisorConfig.BootToBeTemplate = false + config.HypervisorConfig.BootFromTemplate = true + config.HypervisorConfig.MemoryPath = t.statePath + "/memory" + config.HypervisorConfig.DevicesStatePath = t.statePath + "/state" + config.HypervisorConfig.SharedPath = c.HypervisorConfig.SharedPath + config.HypervisorConfig.VMStorePath = c.HypervisorConfig.VMStorePath + config.HypervisorConfig.RunStorePath = c.HypervisorConfig.RunStorePath + + return vc.NewVM(ctx, config) +} + +func (t *template) checkTemplateVM() error { + _, err := os.Stat(t.statePath + "/memory") + if err != nil { + return err + } + + _, err = os.Stat(t.statePath + "/state") + return err +} diff --git a/src/runtime/virtcontainers/mount.go b/src/runtime/virtcontainers/mount.go index 
5e75826199..243c13f330 100644 --- a/src/runtime/virtcontainers/mount.go +++ b/src/runtime/virtcontainers/mount.go @@ -6,29 +6,22 @@ package virtcontainers import ( - "context" "fmt" "os" "path/filepath" "strings" "syscall" - merr "github.com/hashicorp/go-multierror" volume "github.com/kata-containers/kata-containers/src/runtime/pkg/direct-volume" - "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils" "github.com/pkg/errors" "github.com/sirupsen/logrus" - otelLabel "go.opentelemetry.io/otel/attribute" ) // DefaultShmSize is the default shm size to be used in case host // IPC is used. const DefaultShmSize = 65536 * 1024 -// Sadly golang/sys doesn't have UmountNoFollow although it's there since Linux 2.6.34 -const UmountNoFollow = 0x8 - const ( rootfsDir = "rootfs" lowerDir = "lowerdir" @@ -50,13 +43,6 @@ func mountLogger() *logrus.Entry { return virtLog.WithField("subsystem", "mount") } -var propagationTypes = map[string]uintptr{ - "shared": syscall.MS_SHARED, - "private": syscall.MS_PRIVATE, - "slave": syscall.MS_SLAVE, - "ubind": syscall.MS_UNBINDABLE, -} - func isSystemMount(m string) bool { for _, p := range systemMountPrefixes { if m == p || strings.HasPrefix(m, p+"/") { @@ -248,83 +234,6 @@ func evalMountPath(source, destination string) (string, string, error) { return absSource, destination, nil } -// bindMount bind mounts a source in to a destination. This will -// do some bookkeeping: -// * evaluate all symlinks -// * ensure the source exists -// * recursively create the destination -// pgtypes stands for propagation types, which are shared, private, slave, and ubind. 
-func bindMount(ctx context.Context, source, destination string, readonly bool, pgtypes string) error { - span, _ := katatrace.Trace(ctx, nil, "bindMount", mountTracingTags) - defer span.End() - span.SetAttributes(otelLabel.String("source", source), otelLabel.String("destination", destination)) - - absSource, destination, err := evalMountPath(source, destination) - if err != nil { - return err - } - span.SetAttributes(otelLabel.String("source_after_eval", absSource)) - - if err := syscall.Mount(absSource, destination, "bind", syscall.MS_BIND, ""); err != nil { - return fmt.Errorf("Could not bind mount %v to %v: %v", absSource, destination, err) - } - - if pgtype, exist := propagationTypes[pgtypes]; exist { - if err := syscall.Mount("none", destination, "", pgtype, ""); err != nil { - return fmt.Errorf("Could not make mount point %v %s: %v", destination, pgtypes, err) - } - } else { - return fmt.Errorf("Wrong propagation type %s", pgtypes) - } - - // For readonly bind mounts, we need to remount with the readonly flag. - // This is needed as only very recent versions of libmount/util-linux support "bind,ro" - if readonly { - return syscall.Mount(absSource, destination, "bind", uintptr(syscall.MS_BIND|syscall.MS_REMOUNT|syscall.MS_RDONLY), "") - } - - return nil -} - -// An existing mount may be remounted by specifying `MS_REMOUNT` in -// mountflags. -// This allows you to change the mountflags of an existing mount. -// The mountflags should match the values used in the original mount() call, -// except for those parameters that you are trying to change. 
-func remount(ctx context.Context, mountflags uintptr, src string) error { - span, _ := katatrace.Trace(ctx, nil, "remount", mountTracingTags) - defer span.End() - span.SetAttributes(otelLabel.String("source", src)) - - absSrc, err := filepath.EvalSymlinks(src) - if err != nil { - return fmt.Errorf("Could not resolve symlink for %s", src) - } - span.SetAttributes(otelLabel.String("source_after_eval", absSrc)) - - if err := syscall.Mount(absSrc, absSrc, "", syscall.MS_REMOUNT|mountflags, ""); err != nil { - return fmt.Errorf("remount %s failed: %v", absSrc, err) - } - - return nil -} - -// remount a mount point as readonly -func remountRo(ctx context.Context, src string) error { - return remount(ctx, syscall.MS_BIND|syscall.MS_RDONLY, src) -} - -// bindMountContainerRootfs bind mounts a container rootfs into a 9pfs shared -// directory between the guest and the host. -func bindMountContainerRootfs(ctx context.Context, shareDir, cid, cRootFs string, readonly bool) error { - span, _ := katatrace.Trace(ctx, nil, "bindMountContainerRootfs", mountTracingTags) - defer span.End() - - rootfsDest := filepath.Join(shareDir, cid, rootfsDir) - - return bindMount(ctx, cRootFs, rootfsDest, readonly, "private") -} - // Mount describes a container mount. 
// nolint: govet type Mount struct { @@ -372,96 +281,6 @@ func isSymlink(path string) bool { return stat.Mode()&os.ModeSymlink != 0 } -func bindUnmountContainerShareDir(ctx context.Context, sharedDir, cID, target string) error { - destDir := filepath.Join(sharedDir, cID, target) - if isSymlink(filepath.Join(sharedDir, cID)) || isSymlink(destDir) { - mountLogger().WithField("container", cID).Warnf("container dir is a symlink, malicious guest?") - return nil - } - - err := syscall.Unmount(destDir, syscall.MNT_DETACH|UmountNoFollow) - if err == syscall.ENOENT { - mountLogger().WithError(err).WithField("share-dir", destDir).Warn() - return nil - } - if err := syscall.Rmdir(destDir); err != nil { - mountLogger().WithError(err).WithField("share-dir", destDir).Warn("Could not remove container share dir") - } - - return err -} - -func bindUnmountContainerRootfs(ctx context.Context, sharedDir, cID string) error { - span, _ := katatrace.Trace(ctx, nil, "bindUnmountContainerRootfs", mountTracingTags) - defer span.End() - span.SetAttributes(otelLabel.String("shared-dir", sharedDir), otelLabel.String("container-id", cID)) - return bindUnmountContainerShareDir(ctx, sharedDir, cID, rootfsDir) -} - -func bindUnmountContainerSnapshotDir(ctx context.Context, sharedDir, cID string) error { - span, _ := katatrace.Trace(ctx, nil, "bindUnmountContainerSnapshotDir", mountTracingTags) - defer span.End() - span.SetAttributes(otelLabel.String("shared-dir", sharedDir), otelLabel.String("container-id", cID)) - return bindUnmountContainerShareDir(ctx, sharedDir, cID, snapshotDir) -} - -func getVirtiofsDaemonForNydus(sandbox *Sandbox) (VirtiofsDaemon, error) { - var virtiofsDaemon VirtiofsDaemon - switch sandbox.GetHypervisorType() { - case string(QemuHypervisor): - virtiofsDaemon = sandbox.hypervisor.(*qemu).virtiofsDaemon - case string(ClhHypervisor): - virtiofsDaemon = sandbox.hypervisor.(*cloudHypervisor).virtiofsDaemon - default: - return nil, errNydusdNotSupport - } - return 
virtiofsDaemon, nil -} - -func nydusContainerCleanup(ctx context.Context, sharedDir string, c *Container) error { - sandbox := c.sandbox - virtiofsDaemon, err := getVirtiofsDaemonForNydus(sandbox) - if err != nil { - return err - } - if err := virtiofsDaemon.Umount(rafsMountPath(c.id)); err != nil { - return errors.Wrap(err, "umount rafs failed") - } - if err := bindUnmountContainerSnapshotDir(ctx, sharedDir, c.id); err != nil { - return errors.Wrap(err, "umount snapshotdir err") - } - destDir := filepath.Join(sharedDir, c.id, c.rootfsSuffix) - if err := syscall.Rmdir(destDir); err != nil { - return errors.Wrap(err, "remove container rootfs err") - } - return nil -} - -func bindUnmountAllRootfs(ctx context.Context, sharedDir string, sandbox *Sandbox) error { - span, ctx := katatrace.Trace(ctx, nil, "bindUnmountAllRootfs", mountTracingTags) - defer span.End() - span.SetAttributes(otelLabel.String("shared-dir", sharedDir), otelLabel.String("sandbox-id", sandbox.id)) - - var errors *merr.Error - for _, c := range sandbox.containers { - if isSymlink(filepath.Join(sharedDir, c.id)) { - mountLogger().WithField("container", c.id).Warnf("container dir is a symlink, malicious guest?") - continue - } - c.unmountHostMounts(ctx) - if c.state.Fstype == "" { - // even if error found, don't break out of loop until all mounts attempted - // to be unmounted, and collect all errors - if c.rootFs.Type == NydusRootFSType { - errors = merr.Append(errors, nydusContainerCleanup(ctx, sharedDir, c)) - } else { - errors = merr.Append(errors, bindUnmountContainerRootfs(ctx, sharedDir, c.id)) - } - } - } - return errors.ErrorOrNil() -} - const ( dockerVolumePrefix = "/var/lib/docker/volumes" dockerVolumeSuffix = "_data" diff --git a/src/runtime/virtcontainers/mount_darwin.go b/src/runtime/virtcontainers/mount_darwin.go new file mode 100644 index 0000000000..90af2a3aa3 --- /dev/null +++ b/src/runtime/virtcontainers/mount_darwin.go @@ -0,0 +1,12 @@ +// Copyright (c) 2023 Apple Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +// + +package virtcontainers + +import "context" + +func nydusContainerCleanup(ctx context.Context, sharedDir string, c *Container) error { + return nil +} diff --git a/src/runtime/virtcontainers/mount_linux.go b/src/runtime/virtcontainers/mount_linux.go new file mode 100644 index 0000000000..be76a93a69 --- /dev/null +++ b/src/runtime/virtcontainers/mount_linux.go @@ -0,0 +1,195 @@ +// Copyright (c) 2017 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package virtcontainers + +import ( + "context" + "fmt" + "path/filepath" + "syscall" + + merr "github.com/hashicorp/go-multierror" + "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" + "github.com/pkg/errors" + otelLabel "go.opentelemetry.io/otel/attribute" +) + +// Sadly golang/sys doesn't have UmountNoFollow although it's there since Linux 2.6.34 +const UmountNoFollow = 0x8 + +var propagationTypes = map[string]uintptr{ + "shared": syscall.MS_SHARED, + "private": syscall.MS_PRIVATE, + "slave": syscall.MS_SLAVE, + "ubind": syscall.MS_UNBINDABLE, +} + +// bindMount bind mounts a source in to a destination. This will +// do some bookkeeping: +// * evaluate all symlinks +// * ensure the source exists +// * recursively create the destination +// pgtypes stands for propagation types, which are shared, private, slave, and ubind. 
+func bindMount(ctx context.Context, source, destination string, readonly bool, pgtypes string) error { + span, _ := katatrace.Trace(ctx, nil, "bindMount", mountTracingTags) + defer span.End() + span.SetAttributes(otelLabel.String("source", source), otelLabel.String("destination", destination)) + + absSource, destination, err := evalMountPath(source, destination) + if err != nil { + return err + } + span.SetAttributes(otelLabel.String("source_after_eval", absSource)) + + if err := syscall.Mount(absSource, destination, "bind", syscall.MS_BIND, ""); err != nil { + return fmt.Errorf("Could not bind mount %v to %v: %v", absSource, destination, err) + } + + if pgtype, exist := propagationTypes[pgtypes]; exist { + if err := syscall.Mount("none", destination, "", pgtype, ""); err != nil { + return fmt.Errorf("Could not make mount point %v %s: %v", destination, pgtypes, err) + } + } else { + return fmt.Errorf("Wrong propagation type %s", pgtypes) + } + + // For readonly bind mounts, we need to remount with the readonly flag. + // This is needed as only very recent versions of libmount/util-linux support "bind,ro" + if readonly { + return syscall.Mount(absSource, destination, "bind", uintptr(syscall.MS_BIND|syscall.MS_REMOUNT|syscall.MS_RDONLY), "") + } + + return nil +} + +// An existing mount may be remounted by specifying `MS_REMOUNT` in +// mountflags. +// This allows you to change the mountflags of an existing mount. +// The mountflags should match the values used in the original mount() call, +// except for those parameters that you are trying to change. 
+func remount(ctx context.Context, mountflags uintptr, src string) error { + span, _ := katatrace.Trace(ctx, nil, "remount", mountTracingTags) + defer span.End() + span.SetAttributes(otelLabel.String("source", src)) + + absSrc, err := filepath.EvalSymlinks(src) + if err != nil { + return fmt.Errorf("Could not resolve symlink for %s", src) + } + span.SetAttributes(otelLabel.String("source_after_eval", absSrc)) + + if err := syscall.Mount(absSrc, absSrc, "", syscall.MS_REMOUNT|mountflags, ""); err != nil { + return fmt.Errorf("remount %s failed: %v", absSrc, err) + } + + return nil +} + +// remount a mount point as readonly +func remountRo(ctx context.Context, src string) error { + return remount(ctx, syscall.MS_BIND|syscall.MS_RDONLY, src) +} + +// bindMountContainerRootfs bind mounts a container rootfs into a 9pfs shared +// directory between the guest and the host. +func bindMountContainerRootfs(ctx context.Context, shareDir, cid, cRootFs string, readonly bool) error { + span, _ := katatrace.Trace(ctx, nil, "bindMountContainerRootfs", mountTracingTags) + defer span.End() + + rootfsDest := filepath.Join(shareDir, cid, rootfsDir) + + return bindMount(ctx, cRootFs, rootfsDest, readonly, "private") +} + +func bindUnmountContainerShareDir(ctx context.Context, sharedDir, cID, target string) error { + destDir := filepath.Join(sharedDir, cID, target) + if isSymlink(filepath.Join(sharedDir, cID)) || isSymlink(destDir) { + mountLogger().WithField("container", cID).Warnf("container dir is a symlink, malicious guest?") + return nil + } + + err := syscall.Unmount(destDir, syscall.MNT_DETACH|UmountNoFollow) + if err == syscall.ENOENT { + mountLogger().WithError(err).WithField("share-dir", destDir).Warn() + return nil + } + if err := syscall.Rmdir(destDir); err != nil { + mountLogger().WithError(err).WithField("share-dir", destDir).Warn("Could not remove container share dir") + } + + return err +} + +func bindUnmountContainerRootfs(ctx context.Context, sharedDir, cID string) 
error { + span, _ := katatrace.Trace(ctx, nil, "bindUnmountContainerRootfs", mountTracingTags) + defer span.End() + span.SetAttributes(otelLabel.String("shared-dir", sharedDir), otelLabel.String("container-id", cID)) + return bindUnmountContainerShareDir(ctx, sharedDir, cID, rootfsDir) +} + +func bindUnmountContainerSnapshotDir(ctx context.Context, sharedDir, cID string) error { + span, _ := katatrace.Trace(ctx, nil, "bindUnmountContainerSnapshotDir", mountTracingTags) + defer span.End() + span.SetAttributes(otelLabel.String("shared-dir", sharedDir), otelLabel.String("container-id", cID)) + return bindUnmountContainerShareDir(ctx, sharedDir, cID, snapshotDir) +} + +func getVirtiofsDaemonForNydus(sandbox *Sandbox) (VirtiofsDaemon, error) { + var virtiofsDaemon VirtiofsDaemon + switch sandbox.GetHypervisorType() { + case string(QemuHypervisor): + virtiofsDaemon = sandbox.hypervisor.(*qemu).virtiofsDaemon + case string(ClhHypervisor): + virtiofsDaemon = sandbox.hypervisor.(*cloudHypervisor).virtiofsDaemon + default: + return nil, errNydusdNotSupport + } + return virtiofsDaemon, nil +} + +func nydusContainerCleanup(ctx context.Context, sharedDir string, c *Container) error { + sandbox := c.sandbox + virtiofsDaemon, err := getVirtiofsDaemonForNydus(sandbox) + if err != nil { + return err + } + if err := virtiofsDaemon.Umount(rafsMountPath(c.id)); err != nil { + return errors.Wrap(err, "umount rafs failed") + } + if err := bindUnmountContainerSnapshotDir(ctx, sharedDir, c.id); err != nil { + return errors.Wrap(err, "umount snapshotdir err") + } + destDir := filepath.Join(sharedDir, c.id, c.rootfsSuffix) + if err := syscall.Rmdir(destDir); err != nil { + return errors.Wrap(err, "remove container rootfs err") + } + return nil +} + +func bindUnmountAllRootfs(ctx context.Context, sharedDir string, sandbox *Sandbox) error { + span, ctx := katatrace.Trace(ctx, nil, "bindUnmountAllRootfs", mountTracingTags) + defer span.End() + span.SetAttributes(otelLabel.String("shared-dir", 
sharedDir), otelLabel.String("sandbox-id", sandbox.id)) + + var errors *merr.Error + for _, c := range sandbox.containers { + if isSymlink(filepath.Join(sharedDir, c.id)) { + mountLogger().WithField("container", c.id).Warnf("container dir is a symlink, malicious guest?") + continue + } + c.unmountHostMounts(ctx) + if c.state.Fstype == "" { + // even if error found, don't break out of loop until all mounts attempted + // to be unmounted, and collect all errors + if c.rootFs.Type == NydusRootFSType { + errors = merr.Append(errors, nydusContainerCleanup(ctx, sharedDir, c)) + } else { + errors = merr.Append(errors, bindUnmountContainerRootfs(ctx, sharedDir, c.id)) + } + } + } + return errors.ErrorOrNil() +} diff --git a/src/runtime/virtcontainers/mount_linux_test.go b/src/runtime/virtcontainers/mount_linux_test.go new file mode 100644 index 0000000000..a34f7c28f3 --- /dev/null +++ b/src/runtime/virtcontainers/mount_linux_test.go @@ -0,0 +1,321 @@ +// Copyright (c) 2017 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package virtcontainers + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + + ktu "github.com/kata-containers/kata-containers/src/runtime/pkg/katatestutils" +) + +func TestIsEphemeralStorage(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + dir, err := os.MkdirTemp(testDir, "foo") + assert.NoError(err) + defer os.RemoveAll(dir) + + sampleEphePath := filepath.Join(dir, K8sEmptyDir, "tmp-volume") + err = os.MkdirAll(sampleEphePath, testDirMode) + assert.Nil(err) + + err = syscall.Mount("tmpfs", sampleEphePath, "tmpfs", 0, "") + assert.NoError(err) + defer syscall.Unmount(sampleEphePath, 0) + + isEphe := IsEphemeralStorage(sampleEphePath) + assert.True(isEphe) + + isHostEmptyDir := Isk8sHostEmptyDir(sampleEphePath) + assert.False(isHostEmptyDir) + + 
sampleEphePath = "/var/lib/kubelet/pods/366c3a75-4869-11e8-b479-507b9ddd5ce4/volumes/cache-volume" + isEphe = IsEphemeralStorage(sampleEphePath) + assert.False(isEphe) + + isHostEmptyDir = Isk8sHostEmptyDir(sampleEphePath) + assert.False(isHostEmptyDir) +} + +func TestGetDeviceForPathBindMount(t *testing.T) { + assert := assert.New(t) + + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + source := filepath.Join(testDir, "testDeviceDirSrc") + dest := filepath.Join(testDir, "testDeviceDirDest") + syscall.Unmount(dest, 0) + os.Remove(source) + os.Remove(dest) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + defer os.Remove(source) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + defer os.Remove(dest) + + err = bindMount(context.Background(), source, dest, false, "private") + assert.NoError(err) + + defer syscall.Unmount(dest, syscall.MNT_DETACH) + + destFile := filepath.Join(dest, "test") + _, err = os.Create(destFile) + assert.NoError(err) + + defer os.Remove(destFile) + + sourceDev, _ := getDeviceForPath(source) + destDev, _ := getDeviceForPath(destFile) + + assert.Equal(sourceDev, destDev) +} + +func TestBindMountInvalidSourceSymlink(t *testing.T) { + source := filepath.Join(testDir, "fooFile") + os.Remove(source) + + err := bindMount(context.Background(), source, "", false, "private") + assert.Error(t, err) +} + +func TestBindMountFailingMount(t *testing.T) { + source := filepath.Join(testDir, "fooLink") + fakeSource := filepath.Join(testDir, "fooFile") + os.Remove(source) + os.Remove(fakeSource) + assert := assert.New(t) + + _, err := os.OpenFile(fakeSource, os.O_CREATE, mountPerm) + assert.NoError(err) + + err = os.Symlink(fakeSource, source) + assert.NoError(err) + + err = bindMount(context.Background(), source, "", false, "private") + assert.Error(err) +} + +func cleanupFooMount() { + dest := filepath.Join(testDir, "fooDirDest") + + syscall.Unmount(dest, 0) +} + +func TestBindMountSuccessful(t 
*testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + t.Cleanup(cleanupFooMount) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, false, "private") + assert.NoError(err) +} + +func TestBindMountReadonlySuccessful(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + t.Cleanup(cleanupFooMount) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, true, "private") + assert.NoError(err) + + // should not be able to create file in read-only mount + destFile := filepath.Join(dest, "foo") + _, err = os.OpenFile(destFile, os.O_CREATE, mountPerm) + assert.Error(err) +} + +func TestBindMountInvalidPgtypes(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + t.Cleanup(cleanupFooMount) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, false, "foo") + expectedErr := fmt.Sprintf("Wrong propagation type %s", "foo") + assert.EqualError(err, expectedErr) +} + +// TestBindUnmountContainerRootfsENOENTNotError tests that if a file +// or directory attempting to be unmounted doesn't exist, then it +// is not considered an error +func TestBindUnmountContainerRootfsENOENTNotError(t *testing.T) { + if os.Getuid() != 0 { + 
t.Skip("Test disabled as requires root user") + } + testMnt := "/tmp/test_mount" + sID := "sandIDTest" + cID := "contIDTest" + assert := assert.New(t) + + // Check to make sure the file doesn't exist + testPath := filepath.Join(testMnt, sID, cID, rootfsDir) + if _, err := os.Stat(testPath); !os.IsNotExist(err) { + assert.NoError(os.Remove(testPath)) + } + + err := bindUnmountContainerRootfs(context.Background(), filepath.Join(testMnt, sID), cID) + assert.NoError(err) +} + +func TestBindUnmountContainerRootfsRemoveRootfsDest(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(ktu.TestDisabledNeedRoot) + } + + sID := "sandIDTestRemoveRootfsDest" + cID := "contIDTestRemoveRootfsDest" + + testPath := filepath.Join(testDir, sID, cID, rootfsDir) + syscall.Unmount(testPath, 0) + os.Remove(testPath) + + err := os.MkdirAll(testPath, mountPerm) + assert.NoError(err) + defer os.RemoveAll(filepath.Join(testDir, sID)) + + bindUnmountContainerRootfs(context.Background(), filepath.Join(testDir, sID), cID) + + if _, err := os.Stat(testPath); err == nil { + t.Fatal("empty rootfs dest should be removed") + } else if !os.IsNotExist(err) { + t.Fatal(err) + } +} + +func TestIsHostDevice(t *testing.T) { + assert := assert.New(t) + tests := []struct { + mnt string + expected bool + }{ + {"/dev", true}, + {"/dev/zero", true}, + {"/dev/block", true}, + {"/mnt/dev/block", false}, + } + + for _, test := range tests { + result := isHostDevice(test.mnt) + assert.Equal(result, test.expected) + } +} + +func TestMajorMinorNumber(t *testing.T) { + assert := assert.New(t) + devices := []string{"/dev/zero", "/dev/net/tun"} + + for _, device := range devices { + cmdStr := fmt.Sprintf("ls -l %s | awk '{print $5$6}'", device) + cmd := exec.Command("sh", "-c", cmdStr) + output, err := cmd.Output() + assert.NoError(err) + + data := bytes.Split(output, []byte(",")) + assert.False(len(data) < 2) + + majorStr := strings.TrimSpace(string(data[0])) + minorStr := 
strings.TrimSpace(string(data[1])) + + majorNo, err := strconv.Atoi(majorStr) + assert.NoError(err) + + minorNo, err := strconv.Atoi(minorStr) + assert.NoError(err) + + stat := syscall.Stat_t{} + err = syscall.Stat(device, &stat) + assert.NoError(err) + + // Get major and minor numbers for the device itself. Note the use of stat.Rdev instead of Dev. + major := major(uint64(stat.Rdev)) + minor := minor(uint64(stat.Rdev)) + + assert.Equal(minor, minorNo) + assert.Equal(major, majorNo) + } +} + +func TestGetDeviceForPathValidMount(t *testing.T) { + assert := assert.New(t) + dev, err := getDeviceForPath("/proc") + assert.NoError(err) + + expected := "/proc" + + assert.Equal(dev.mountPoint, expected) +} + +func TestIsDeviceMapper(t *testing.T) { + assert := assert.New(t) + + // known major, minor for /dev/tty + major := 5 + minor := 0 + + isDM, err := isDeviceMapper(major, minor) + assert.NoError(err) + assert.False(isDM) + + // fake the block device format + blockFormatTemplate = "/sys/dev/char/%d:%d" + isDM, err = isDeviceMapper(major, minor) + assert.NoError(err) + assert.True(isDM) +} diff --git a/src/runtime/virtcontainers/mount_test.go b/src/runtime/virtcontainers/mount_test.go index 056fa2c140..6d91d22a7b 100644 --- a/src/runtime/virtcontainers/mount_test.go +++ b/src/runtime/virtcontainers/mount_test.go @@ -6,15 +6,9 @@ package virtcontainers import ( - "bytes" - "context" "fmt" "os" - "os/exec" "path/filepath" - "strconv" - "strings" - "syscall" "testing" ktu "github.com/kata-containers/kata-containers/src/runtime/pkg/katatestutils" @@ -55,24 +49,6 @@ func TestIsSystemMount(t *testing.T) { } } -func TestIsHostDevice(t *testing.T) { - assert := assert.New(t) - tests := []struct { - mnt string - expected bool - }{ - {"/dev", true}, - {"/dev/zero", true}, - {"/dev/block", true}, - {"/mnt/dev/block", false}, - } - - for _, test := range tests { - result := isHostDevice(test.mnt) - assert.Equal(result, test.expected) - } -} - func TestIsHostDeviceCreateFile(t 
*testing.T) { assert := assert.New(t) if tc.NotValid(ktu.NeedRoot()) { @@ -89,41 +65,6 @@ func TestIsHostDeviceCreateFile(t *testing.T) { assert.NoError(os.Remove(path)) } -func TestMajorMinorNumber(t *testing.T) { - assert := assert.New(t) - devices := []string{"/dev/zero", "/dev/net/tun"} - - for _, device := range devices { - cmdStr := fmt.Sprintf("ls -l %s | awk '{print $5$6}'", device) - cmd := exec.Command("sh", "-c", cmdStr) - output, err := cmd.Output() - assert.NoError(err) - - data := bytes.Split(output, []byte(",")) - assert.False(len(data) < 2) - - majorStr := strings.TrimSpace(string(data[0])) - minorStr := strings.TrimSpace(string(data[1])) - - majorNo, err := strconv.Atoi(majorStr) - assert.NoError(err) - - minorNo, err := strconv.Atoi(minorStr) - assert.NoError(err) - - stat := syscall.Stat_t{} - err = syscall.Stat(device, &stat) - assert.NoError(err) - - // Get major and minor numbers for the device itself. Note the use of stat.Rdev instead of Dev. - major := major(stat.Rdev) - minor := minor(stat.Rdev) - - assert.Equal(minor, minorNo) - assert.Equal(major, majorNo) - } -} - func TestGetDeviceForPathRoot(t *testing.T) { assert := assert.New(t) dev, err := getDeviceForPath("/") @@ -134,16 +75,6 @@ func TestGetDeviceForPathRoot(t *testing.T) { assert.Equal(dev.mountPoint, expected) } -func TestGetDeviceForPathValidMount(t *testing.T) { - assert := assert.New(t) - dev, err := getDeviceForPath("/proc") - assert.NoError(err) - - expected := "/proc" - - assert.Equal(dev.mountPoint, expected) -} - func TestGetDeviceForPathEmptyPath(t *testing.T) { assert := assert.New(t) _, err := getDeviceForPath("") @@ -165,64 +96,6 @@ func TestGetDeviceForPath(t *testing.T) { assert.Error(err) } -func TestGetDeviceForPathBindMount(t *testing.T) { - assert := assert.New(t) - - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(ktu.TestDisabledNeedRoot) - } - - source := filepath.Join(testDir, "testDeviceDirSrc") - dest := filepath.Join(testDir, "testDeviceDirDest") - 
syscall.Unmount(dest, 0) - os.Remove(source) - os.Remove(dest) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - defer os.Remove(source) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - defer os.Remove(dest) - - err = bindMount(context.Background(), source, dest, false, "private") - assert.NoError(err) - - defer syscall.Unmount(dest, syscall.MNT_DETACH) - - destFile := filepath.Join(dest, "test") - _, err = os.Create(destFile) - assert.NoError(err) - - defer os.Remove(destFile) - - sourceDev, _ := getDeviceForPath(source) - destDev, _ := getDeviceForPath(destFile) - - assert.Equal(sourceDev, destDev) -} - -func TestIsDeviceMapper(t *testing.T) { - assert := assert.New(t) - - // known major, minor for /dev/tty - major := 5 - minor := 0 - - isDM, err := isDeviceMapper(major, minor) - assert.NoError(err) - assert.False(isDM) - - // fake the block device format - blockFormatTemplate = "/sys/dev/char/%d:%d" - isDM, err = isDeviceMapper(major, minor) - assert.NoError(err) - assert.True(isDM) -} - func TestIsDockerVolume(t *testing.T) { assert := assert.New(t) path := "/var/lib/docker/volumes/00da1347c7cf4f15db35f/_data" @@ -234,38 +107,6 @@ func TestIsDockerVolume(t *testing.T) { assert.False(isDockerVolume) } -func TestIsEphemeralStorage(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(ktu.TestDisabledNeedRoot) - } - - dir, err := os.MkdirTemp(testDir, "foo") - assert.NoError(err) - defer os.RemoveAll(dir) - - sampleEphePath := filepath.Join(dir, K8sEmptyDir, "tmp-volume") - err = os.MkdirAll(sampleEphePath, testDirMode) - assert.Nil(err) - - err = syscall.Mount("tmpfs", sampleEphePath, "tmpfs", 0, "") - assert.NoError(err) - defer syscall.Unmount(sampleEphePath, 0) - - isEphe := IsEphemeralStorage(sampleEphePath) - assert.True(isEphe) - - isHostEmptyDir := Isk8sHostEmptyDir(sampleEphePath) - assert.False(isHostEmptyDir) - - sampleEphePath = 
"/var/lib/kubelet/pods/366c3a75-4869-11e8-b479-507b9ddd5ce4/volumes/cache-volume" - isEphe = IsEphemeralStorage(sampleEphePath) - assert.False(isEphe) - - isHostEmptyDir = Isk8sHostEmptyDir(sampleEphePath) - assert.False(isHostEmptyDir) -} - func TestIsEmtpyDir(t *testing.T) { assert := assert.New(t) path := "/var/lib/kubelet/pods/5f0861a0-a987-4a3a-bb0f-1058ddb9678f/volumes/kubernetes.io~empty-dir/foobar" @@ -355,148 +196,3 @@ func TestIsWatchable(t *testing.T) { result = isWatchableMount(configs) assert.False(result) } - -func TestBindMountInvalidSourceSymlink(t *testing.T) { - source := filepath.Join(testDir, "fooFile") - os.Remove(source) - - err := bindMount(context.Background(), source, "", false, "private") - assert.Error(t, err) -} - -func TestBindMountFailingMount(t *testing.T) { - source := filepath.Join(testDir, "fooLink") - fakeSource := filepath.Join(testDir, "fooFile") - os.Remove(source) - os.Remove(fakeSource) - assert := assert.New(t) - - _, err := os.OpenFile(fakeSource, os.O_CREATE, mountPerm) - assert.NoError(err) - - err = os.Symlink(fakeSource, source) - assert.NoError(err) - - err = bindMount(context.Background(), source, "", false, "private") - assert.Error(err) -} - -func cleanupFooMount() { - dest := filepath.Join(testDir, "fooDirDest") - - syscall.Unmount(dest, 0) -} - -func TestBindMountSuccessful(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(testDisabledAsNonRoot) - } - - source := filepath.Join(testDir, "fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") - t.Cleanup(cleanupFooMount) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - err = bindMount(context.Background(), source, dest, false, "private") - assert.NoError(err) -} - -func TestBindMountReadonlySuccessful(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(testDisabledAsNonRoot) - } - - source := filepath.Join(testDir, 
"fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") - t.Cleanup(cleanupFooMount) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - err = bindMount(context.Background(), source, dest, true, "private") - assert.NoError(err) - - // should not be able to create file in read-only mount - destFile := filepath.Join(dest, "foo") - _, err = os.OpenFile(destFile, os.O_CREATE, mountPerm) - assert.Error(err) -} - -func TestBindMountInvalidPgtypes(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(testDisabledAsNonRoot) - } - - source := filepath.Join(testDir, "fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") - t.Cleanup(cleanupFooMount) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - err = bindMount(context.Background(), source, dest, false, "foo") - expectedErr := fmt.Sprintf("Wrong propagation type %s", "foo") - assert.EqualError(err, expectedErr) -} - -// TestBindUnmountContainerRootfsENOENTNotError tests that if a file -// or directory attempting to be unmounted doesn't exist, then it -// is not considered an error -func TestBindUnmountContainerRootfsENOENTNotError(t *testing.T) { - if os.Getuid() != 0 { - t.Skip("Test disabled as requires root user") - } - testMnt := "/tmp/test_mount" - sID := "sandIDTest" - cID := "contIDTest" - assert := assert.New(t) - - // Check to make sure the file doesn't exist - testPath := filepath.Join(testMnt, sID, cID, rootfsDir) - if _, err := os.Stat(testPath); !os.IsNotExist(err) { - assert.NoError(os.Remove(testPath)) - } - - err := bindUnmountContainerRootfs(context.Background(), filepath.Join(testMnt, sID), cID) - assert.NoError(err) -} - -func TestBindUnmountContainerRootfsRemoveRootfsDest(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(ktu.TestDisabledNeedRoot) - } - - sID := 
"sandIDTestRemoveRootfsDest" - cID := "contIDTestRemoveRootfsDest" - - testPath := filepath.Join(testDir, sID, cID, rootfsDir) - syscall.Unmount(testPath, 0) - os.Remove(testPath) - - err := os.MkdirAll(testPath, mountPerm) - assert.NoError(err) - defer os.RemoveAll(filepath.Join(testDir, sID)) - - bindUnmountContainerRootfs(context.Background(), filepath.Join(testDir, sID), cID) - - if _, err := os.Stat(testPath); err == nil { - t.Fatal("empty rootfs dest should be removed") - } else if !os.IsNotExist(err) { - t.Fatal(err) - } -}