diff --git a/src/runtime/pkg/katautils/create.go b/src/runtime/pkg/katautils/create.go index d2c9c69cfb..3c3bf05c8d 100644 --- a/src/runtime/pkg/katautils/create.go +++ b/src/runtime/pkg/katautils/create.go @@ -162,31 +162,27 @@ func CreateSandbox(ctx context.Context, vci vc.VC, ociSpec specs.Spec, runtimeCo ociSpec.Annotations["nerdctl/network-namespace"] = sandboxConfig.NetworkConfig.NetworkID sandboxConfig.Annotations["nerdctl/network-namespace"] = ociSpec.Annotations["nerdctl/network-namespace"] - sandbox, err := vci.CreateSandbox(ctx, sandboxConfig) - if err != nil { - return nil, vc.Process{}, err - } + sandbox, err := vci.CreateSandbox(ctx, sandboxConfig, func(ctx context.Context) error { + // Run pre-start OCI hooks, in the runtime namespace. + if err := PreStartHooks(ctx, ociSpec, containerID, bundlePath); err != nil { + return err + } - hid, err := sandbox.GetHypervisorPid() + // Run create runtime OCI hooks, in the runtime namespace. + if err := CreateRuntimeHooks(ctx, ociSpec, containerID, bundlePath); err != nil { + return err + } + + return nil + }) if err != nil { return nil, vc.Process{}, err } - ctx = context.WithValue(ctx, "hypervisor-pid", hid) sid := sandbox.ID() kataUtilsLogger = kataUtilsLogger.WithField("sandbox", sid) katatrace.AddTags(span, "sandbox_id", sid) - // Run pre-start OCI hooks, in the runtime namespace. - if err := PreStartHooks(ctx, ociSpec, containerID, bundlePath); err != nil { - return nil, vc.Process{}, err - } - - // Run create runtime OCI hooks, in the runtime namespace. - if err := CreateRuntimeHooks(ctx, ociSpec, containerID, bundlePath); err != nil { - return nil, vc.Process{}, err - } - containers := sandbox.GetAllContainers() if len(containers) != 1 { return nil, vc.Process{}, fmt.Errorf("BUG: Container list from sandbox is wrong, expecting only one container, found %d containers", len(containers)) @@ -265,7 +261,7 @@ func CreateContainer(ctx context.Context, sandbox vc.VCSandbox, ociSpec specs.Sp if err != nil { return vc.Process{}, err } - ctx = context.WithValue(ctx, HypervisorPidKey{}, hid) + ctx = context.WithValue(ctx, vc.HypervisorPidKey{}, hid) // Run pre-start OCI hooks. err = EnterNetNS(sandbox.GetNetNs(), func() error { diff --git a/src/runtime/pkg/katautils/hook.go b/src/runtime/pkg/katautils/hook.go index 02a4f75973..8ed6361ae1 100644 --- a/src/runtime/pkg/katautils/hook.go +++ b/src/runtime/pkg/katautils/hook.go @@ -17,6 +17,7 @@ import ( "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" syscallWrapper "github.com/kata-containers/kata-containers/src/runtime/pkg/syscall" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" ) @@ -28,8 +29,6 @@ var hookTracingTags = map[string]string{ "subsystem": "hook", } -type HypervisorPidKey struct{} - // Logger returns a logrus logger appropriate for logging hook messages func hookLogger() *logrus.Entry { return kataUtilsLogger.WithField("subsystem", "hook") @@ -40,7 +39,7 @@ func runHook(ctx context.Context, spec specs.Spec, hook specs.Hook, cid, bundleP defer span.End() katatrace.AddTags(span, "path", hook.Path, "args", hook.Args) - pid, ok := ctx.Value(HypervisorPidKey{}).(int) + pid, ok := ctx.Value(vc.HypervisorPidKey{}).(int) if !ok || pid == 0 { hookLogger().Info("no hypervisor pid") diff --git a/src/runtime/virtcontainers/api.go b/src/runtime/virtcontainers/api.go index 437c926a7e..1927d4d077 100644 --- a/src/runtime/virtcontainers/api.go +++ b/src/runtime/virtcontainers/api.go @@ -44,16 +44,16 @@ func SetLogger(ctx context.Context, logger *logrus.Entry) { // CreateSandbox is the virtcontainers sandbox creation entry point. // CreateSandbox creates a sandbox and its containers. It does not start them. -func CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factory) (VCSandbox, error) { +func CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factory, prestartHookFunc func(context.Context) error) (VCSandbox, error) { span, ctx := katatrace.Trace(ctx, virtLog, "CreateSandbox", apiTracingTags) defer span.End() - s, err := createSandboxFromConfig(ctx, sandboxConfig, factory) + s, err := createSandboxFromConfig(ctx, sandboxConfig, factory, prestartHookFunc) return s, err } -func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, factory Factory) (_ *Sandbox, err error) { +func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, factory Factory, prestartHookFunc func(context.Context) error) (_ *Sandbox, err error) { span, ctx := katatrace.Trace(ctx, virtLog, "createSandboxFromConfig", apiTracingTags) defer span.End() @@ -88,7 +88,7 @@ func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, f } // Start the VM - if err = s.startVM(ctx); err != nil { + if err = s.startVM(ctx, prestartHookFunc); err != nil { return nil, err } diff --git a/src/runtime/virtcontainers/implementation.go b/src/runtime/virtcontainers/implementation.go index 177797ebd2..f48e939e41 100644 --- a/src/runtime/virtcontainers/implementation.go +++ b/src/runtime/virtcontainers/implementation.go @@ -31,8 +31,8 @@ func (impl *VCImpl) SetFactory(ctx context.Context, factory Factory) { } // CreateSandbox implements the VC function of the same name. -func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) { - return CreateSandbox(ctx, sandboxConfig, impl.factory) +func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, hookFunc func(context.Context) error) (VCSandbox, error) { + return CreateSandbox(ctx, sandboxConfig, impl.factory, hookFunc) } // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container diff --git a/src/runtime/virtcontainers/interfaces.go b/src/runtime/virtcontainers/interfaces.go index 7664f0281f..492d3f35a7 100644 --- a/src/runtime/virtcontainers/interfaces.go +++ b/src/runtime/virtcontainers/interfaces.go @@ -23,7 +23,7 @@ type VC interface { SetLogger(ctx context.Context, logger *logrus.Entry) SetFactory(ctx context.Context, factory Factory) - CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) + CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, hookFunc func(context.Context) error) (VCSandbox, error) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error } diff --git a/src/runtime/virtcontainers/sandbox.go b/src/runtime/virtcontainers/sandbox.go index 025537fed9..9f87cc2ffe 100644 --- a/src/runtime/virtcontainers/sandbox.go +++ b/src/runtime/virtcontainers/sandbox.go @@ -92,6 +92,9 @@ var ( errSandboxNotRunning = errors.New("Sandbox not running") ) +// HypervisorPidKey is the context key for hypervisor pid +type HypervisorPidKey struct{} + // SandboxStatus describes a sandbox status. type SandboxStatus struct { ContainersStatus []ContainerStatus @@ -1194,7 +1197,7 @@ func (s *Sandbox) cleanSwap(ctx context.Context) { } // startVM starts the VM. -func (s *Sandbox) startVM(ctx context.Context) (err error) { +func (s *Sandbox) startVM(ctx context.Context, prestartHookFunc func(context.Context) error) (err error) { span, ctx := katatrace.Trace(ctx, s.Logger(), "startVM", sandboxTracingTags, map[string]string{"sandbox_id": s.id}) defer span.End() @@ -1234,9 +1237,24 @@ func (s *Sandbox) startVM(ctx context.Context) (err error) { return err } + if prestartHookFunc != nil { + hid, err := s.GetHypervisorPid() + if err != nil { + return err + } + s.Logger().Infof("hypervisor pid is %v", hid) + ctx = context.WithValue(ctx, HypervisorPidKey{}, hid) + + if err := prestartHookFunc(ctx); err != nil { + return err + } + } + // In case of vm factory, network interfaces are hotplugged // after vm is started. - if s.factory != nil { + // In case of prestartHookFunc, network config might have been changed. + // We need to rescan and handle the change. + if s.factory != nil || prestartHookFunc != nil { if _, err := s.network.AddEndpoints(ctx, s, nil, true); err != nil { return err }