diff --git a/src/runtime/pkg/containerd-shim-v2/create_test.go b/src/runtime/pkg/containerd-shim-v2/create_test.go index 121d5ea4d..75638b518 100644 --- a/src/runtime/pkg/containerd-shim-v2/create_test.go +++ b/src/runtime/pkg/containerd-shim-v2/create_test.go @@ -41,7 +41,7 @@ func TestCreateSandboxSuccess(t *testing.T) { }, } - testingImpl.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) { + testingImpl.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig, hookFunc func(context.Context) error) (vc.VCSandbox, error) { return sandbox, nil } diff --git a/src/runtime/pkg/katautils/create.go b/src/runtime/pkg/katautils/create.go index ffcaa0715..3c3bf05c8 100644 --- a/src/runtime/pkg/katautils/create.go +++ b/src/runtime/pkg/katautils/create.go @@ -162,17 +162,19 @@ func CreateSandbox(ctx context.Context, vci vc.VC, ociSpec specs.Spec, runtimeCo ociSpec.Annotations["nerdctl/network-namespace"] = sandboxConfig.NetworkConfig.NetworkID sandboxConfig.Annotations["nerdctl/network-namespace"] = ociSpec.Annotations["nerdctl/network-namespace"] - // Run pre-start OCI hooks, in the runtime namespace. - if err := PreStartHooks(ctx, ociSpec, containerID, bundlePath); err != nil { - return nil, vc.Process{}, err - } + sandbox, err := vci.CreateSandbox(ctx, sandboxConfig, func(ctx context.Context) error { + // Run pre-start OCI hooks, in the runtime namespace. + if err := PreStartHooks(ctx, ociSpec, containerID, bundlePath); err != nil { + return err + } - // Run create runtime OCI hooks, in the runtime namespace. - if err := CreateRuntimeHooks(ctx, ociSpec, containerID, bundlePath); err != nil { - return nil, vc.Process{}, err - } + // Run create runtime OCI hooks, in the runtime namespace. + if err := CreateRuntimeHooks(ctx, ociSpec, containerID, bundlePath); err != nil { + return err + } - sandbox, err := vci.CreateSandbox(ctx, sandboxConfig) + return nil + }) if err != nil { return nil, vc.Process{}, err } @@ -255,6 +257,12 @@ func CreateContainer(ctx context.Context, sandbox vc.VCSandbox, ociSpec specs.Sp return vc.Process{}, err } + hid, err := sandbox.GetHypervisorPid() + if err != nil { + return vc.Process{}, err + } + ctx = context.WithValue(ctx, vc.HypervisorPidKey{}, hid) + // Run pre-start OCI hooks. err = EnterNetNS(sandbox.GetNetNs(), func() error { return PreStartHooks(ctx, ociSpec, containerID, bundlePath) diff --git a/src/runtime/pkg/katautils/create_test.go b/src/runtime/pkg/katautils/create_test.go index b1e4cf2a9..260800378 100644 --- a/src/runtime/pkg/katautils/create_test.go +++ b/src/runtime/pkg/katautils/create_test.go @@ -274,7 +274,7 @@ func TestCreateSandboxAnnotations(t *testing.T) { rootFs := vc.RootFs{Mounted: true} - testingImpl.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) { + testingImpl.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig, hookFunc func(context.Context) error) (vc.VCSandbox, error) { return &vcmock.Sandbox{ MockID: testSandboxID, MockContainers: []*vcmock.Container{ diff --git a/src/runtime/pkg/katautils/hook.go b/src/runtime/pkg/katautils/hook.go index 50ac95cb8..8ed6361ae 100644 --- a/src/runtime/pkg/katautils/hook.go +++ b/src/runtime/pkg/katautils/hook.go @@ -17,6 +17,7 @@ import ( "github.com/kata-containers/kata-containers/src/runtime/pkg/katautils/katatrace" syscallWrapper "github.com/kata-containers/kata-containers/src/runtime/pkg/syscall" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" ) @@ -38,8 +39,16 @@ func runHook(ctx context.Context, spec specs.Spec, hook specs.Hook, cid, bundleP defer span.End() katatrace.AddTags(span, "path", hook.Path, "args", hook.Args) + pid, ok := ctx.Value(vc.HypervisorPidKey{}).(int) + if !ok || pid == 0 { + hookLogger().Info("no hypervisor pid") + + pid = syscallWrapper.Gettid() + } + hookLogger().Infof("hypervisor pid %v", pid) + state := specs.State{ - Pid: syscallWrapper.Gettid(), + Pid: pid, Bundle: bundlePath, ID: cid, Annotations: spec.Annotations, diff --git a/src/runtime/virtcontainers/api.go b/src/runtime/virtcontainers/api.go index 437c926a7..1927d4d07 100644 --- a/src/runtime/virtcontainers/api.go +++ b/src/runtime/virtcontainers/api.go @@ -44,16 +44,16 @@ func SetLogger(ctx context.Context, logger *logrus.Entry) { // CreateSandbox is the virtcontainers sandbox creation entry point. // CreateSandbox creates a sandbox and its containers. It does not start them. -func CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factory) (VCSandbox, error) { +func CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factory, prestartHookFunc func(context.Context) error) (VCSandbox, error) { span, ctx := katatrace.Trace(ctx, virtLog, "CreateSandbox", apiTracingTags) defer span.End() - s, err := createSandboxFromConfig(ctx, sandboxConfig, factory) + s, err := createSandboxFromConfig(ctx, sandboxConfig, factory, prestartHookFunc) return s, err } -func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, factory Factory) (_ *Sandbox, err error) { +func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, factory Factory, prestartHookFunc func(context.Context) error) (_ *Sandbox, err error) { span, ctx := katatrace.Trace(ctx, virtLog, "createSandboxFromConfig", apiTracingTags) defer span.End() @@ -88,7 +88,7 @@ func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, f } // Start the VM - if err = s.startVM(ctx); err != nil { + if err = s.startVM(ctx, prestartHookFunc); err != nil { return nil, err } diff --git a/src/runtime/virtcontainers/api_test.go b/src/runtime/virtcontainers/api_test.go index 0268ab125..0af9aec5d 100644 --- a/src/runtime/virtcontainers/api_test.go +++ b/src/runtime/virtcontainers/api_test.go @@ -145,7 +145,7 @@ func TestCreateSandboxNoopAgentSuccessful(t *testing.T) { config := newTestSandboxConfigNoop() ctx := WithNewAgentFunc(context.Background(), newMockAgent) - p, err := CreateSandbox(ctx, config, nil) + p, err := CreateSandbox(ctx, config, nil, nil) assert.NoError(err) assert.NotNil(p) @@ -178,7 +178,7 @@ func TestCreateSandboxKataAgentSuccessful(t *testing.T) { defer hybridVSockTTRPCMock.Stop() ctx := WithNewAgentFunc(context.Background(), newMockAgent) - p, err := CreateSandbox(ctx, config, nil) + p, err := CreateSandbox(ctx, config, nil, nil) assert.NoError(err) assert.NotNil(p) @@ -199,7 +199,7 @@ func TestCreateSandboxFailing(t *testing.T) { config := SandboxConfig{} ctx := WithNewAgentFunc(context.Background(), newMockAgent) - p, err := CreateSandbox(ctx, config, nil) + p, err := CreateSandbox(ctx, config, nil, nil) assert.Error(err) assert.Nil(p.(*Sandbox)) } @@ -227,7 +227,7 @@ func createAndStartSandbox(ctx context.Context, config SandboxConfig) (sandbox V err error) { // Create sandbox - sandbox, err = CreateSandbox(ctx, config, nil) + sandbox, err = CreateSandbox(ctx, config, nil, nil) if sandbox == nil || err != nil { return nil, "", err } @@ -260,7 +260,7 @@ func TestReleaseSandbox(t *testing.T) { config := newTestSandboxConfigNoop() ctx := WithNewAgentFunc(context.Background(), newMockAgent) - s, err := CreateSandbox(ctx, config, nil) + s, err := CreateSandbox(ctx, config, nil, nil) assert.NoError(t, err) assert.NotNil(t, s) diff --git a/src/runtime/virtcontainers/example_pod_run_test.go b/src/runtime/virtcontainers/example_pod_run_test.go index cc12ddaed..79706b0a2 100644 --- a/src/runtime/virtcontainers/example_pod_run_test.go +++ b/src/runtime/virtcontainers/example_pod_run_test.go @@ -64,7 +64,7 @@ func Example_createAndStartSandbox() { } // Create the sandbox - s, err := vc.CreateSandbox(context.Background(), sandboxConfig, nil) + s, err := vc.CreateSandbox(context.Background(), sandboxConfig, nil, nil) if err != nil { fmt.Printf("Could not create sandbox: %s", err) return diff --git a/src/runtime/virtcontainers/implementation.go b/src/runtime/virtcontainers/implementation.go index 177797ebd..f48e939e4 100644 --- a/src/runtime/virtcontainers/implementation.go +++ b/src/runtime/virtcontainers/implementation.go @@ -31,8 +31,8 @@ func (impl *VCImpl) SetFactory(ctx context.Context, factory Factory) { } // CreateSandbox implements the VC function of the same name. -func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) { - return CreateSandbox(ctx, sandboxConfig, impl.factory) +func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, hookFunc func(context.Context) error) (VCSandbox, error) { + return CreateSandbox(ctx, sandboxConfig, impl.factory, hookFunc) } // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container diff --git a/src/runtime/virtcontainers/interfaces.go b/src/runtime/virtcontainers/interfaces.go index 7664f0281..492d3f35a 100644 --- a/src/runtime/virtcontainers/interfaces.go +++ b/src/runtime/virtcontainers/interfaces.go @@ -23,7 +23,7 @@ type VC interface { SetLogger(ctx context.Context, logger *logrus.Entry) SetFactory(ctx context.Context, factory Factory) - CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) + CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig, hookFunc func(context.Context) error) (VCSandbox, error) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error } diff --git a/src/runtime/virtcontainers/network_linux.go b/src/runtime/virtcontainers/network_linux.go index 3d9f7b7d2..30819bb30 100644 --- a/src/runtime/virtcontainers/network_linux.go +++ b/src/runtime/virtcontainers/network_linux.go @@ -252,6 +252,22 @@ func (n *LinuxNetwork) removeSingleEndpoint(ctx context.Context, s *Sandbox, idx return nil } +func (n *LinuxNetwork) endpointAlreadyAdded(netInfo *NetworkInfo) bool { + for _, ep := range n.eps { + // Existing endpoint + if ep.Name() == netInfo.Iface.Name { + return true + } + pair := ep.NetworkPair() + // Existing virtual endpoints + if pair != nil && (pair.TapInterface.Name == netInfo.Iface.Name || pair.TapInterface.TAPIface.Name == netInfo.Iface.Name || pair.VirtIface.Name == netInfo.Iface.Name) { + return true + } + } + + return false +} + // Scan the networking namespace through netlink and then: // 1. Create the endpoints for the relevant interfaces found there. // 2. Attach them to the VM. @@ -292,6 +308,12 @@ func (n *LinuxNetwork) addAllEndpoints(ctx context.Context, s *Sandbox, hotplug continue } + // Skip any interfaces that are already added + if n.endpointAlreadyAdded(&netInfo) { + networkLogger().WithField("endpoint", netInfo.Iface.Name).Info("already added") + continue + } + if err := doNetNS(n.netNSPath, func(_ ns.NetNS) error { _, err = n.addSingleEndpoint(ctx, s, netInfo, hotplug) return err diff --git a/src/runtime/virtcontainers/pkg/vcmock/mock.go b/src/runtime/virtcontainers/pkg/vcmock/mock.go index 3b1815166..39305e244 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/mock.go +++ b/src/runtime/virtcontainers/pkg/vcmock/mock.go @@ -42,9 +42,9 @@ func (m *VCMock) SetFactory(ctx context.Context, factory vc.Factory) { } // CreateSandbox implements the VC function of the same name. -func (m *VCMock) CreateSandbox(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) { +func (m *VCMock) CreateSandbox(ctx context.Context, sandboxConfig vc.SandboxConfig, hookFunc func(context.Context) error) (vc.VCSandbox, error) { if m.CreateSandboxFunc != nil { - return m.CreateSandboxFunc(ctx, sandboxConfig) + return m.CreateSandboxFunc(ctx, sandboxConfig, hookFunc) } return nil, fmt.Errorf("%s: %s (%+v): sandboxConfig: %v", mockErrorPrefix, getSelf(), m, sandboxConfig) diff --git a/src/runtime/virtcontainers/pkg/vcmock/mock_test.go b/src/runtime/virtcontainers/pkg/vcmock/mock_test.go index 9043b168d..7558b8d1b 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/mock_test.go +++ b/src/runtime/virtcontainers/pkg/vcmock/mock_test.go @@ -120,22 +120,22 @@ func TestVCMockCreateSandbox(t *testing.T) { assert.Nil(m.CreateSandboxFunc) ctx := context.Background() - _, err := m.CreateSandbox(ctx, vc.SandboxConfig{}) + _, err := m.CreateSandbox(ctx, vc.SandboxConfig{}, nil) assert.Error(err) assert.True(IsMockError(err)) - m.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) { + m.CreateSandboxFunc = func(ctx context.Context, sandboxConfig vc.SandboxConfig, hookFunc func(context.Context) error) (vc.VCSandbox, error) { return &Sandbox{}, nil } - sandbox, err := m.CreateSandbox(ctx, vc.SandboxConfig{}) + sandbox, err := m.CreateSandbox(ctx, vc.SandboxConfig{}, nil) assert.NoError(err) assert.Equal(sandbox, &Sandbox{}) // reset m.CreateSandboxFunc = nil - _, err = m.CreateSandbox(ctx, vc.SandboxConfig{}) + _, err = m.CreateSandbox(ctx, vc.SandboxConfig{}, nil) assert.Error(err) assert.True(IsMockError(err)) } diff --git a/src/runtime/virtcontainers/pkg/vcmock/types.go b/src/runtime/virtcontainers/pkg/vcmock/types.go index 05a0a9859..16b811cd5 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/types.go +++ b/src/runtime/virtcontainers/pkg/vcmock/types.go @@ -88,6 +88,6 @@ type VCMock struct { SetLoggerFunc func(ctx context.Context, logger *logrus.Entry) SetFactoryFunc func(ctx context.Context, factory vc.Factory) - CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) + CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig, hookFunc func(context.Context) error) (vc.VCSandbox, error) CleanupContainerFunc func(ctx context.Context, sandboxID, containerID string, force bool) error } diff --git a/src/runtime/virtcontainers/sandbox.go b/src/runtime/virtcontainers/sandbox.go index 025537fed..9f87cc2ff 100644 --- a/src/runtime/virtcontainers/sandbox.go +++ b/src/runtime/virtcontainers/sandbox.go @@ -92,6 +92,9 @@ var ( errSandboxNotRunning = errors.New("Sandbox not running") ) +// HypervisorPidKey is the context key for hypervisor pid +type HypervisorPidKey struct{} + // SandboxStatus describes a sandbox status. type SandboxStatus struct { ContainersStatus []ContainerStatus @@ -1194,7 +1197,7 @@ func (s *Sandbox) cleanSwap(ctx context.Context) { } // startVM starts the VM. -func (s *Sandbox) startVM(ctx context.Context) (err error) { +func (s *Sandbox) startVM(ctx context.Context, prestartHookFunc func(context.Context) error) (err error) { span, ctx := katatrace.Trace(ctx, s.Logger(), "startVM", sandboxTracingTags, map[string]string{"sandbox_id": s.id}) defer span.End() @@ -1234,9 +1237,24 @@ func (s *Sandbox) startVM(ctx context.Context) (err error) { return err } + if prestartHookFunc != nil { + hid, err := s.GetHypervisorPid() + if err != nil { + return err + } + s.Logger().Infof("hypervisor pid is %v", hid) + ctx = context.WithValue(ctx, HypervisorPidKey{}, hid) + + if err := prestartHookFunc(ctx); err != nil { + return err + } + } + // In case of vm factory, network interfaces are hotplugged // after vm is started. - if s.factory != nil { + // In case of prestartHookFunc, network config might have been changed. + // We need to rescan and handle the change. + if s.factory != nil || prestartHookFunc != nil { if _, err := s.network.AddEndpoints(ctx, s, nil, true); err != nil { return err } diff --git a/src/runtime/virtcontainers/sandbox_test.go b/src/runtime/virtcontainers/sandbox_test.go index 59ed24c1a..3d5154eae 100644 --- a/src/runtime/virtcontainers/sandbox_test.go +++ b/src/runtime/virtcontainers/sandbox_test.go @@ -1348,7 +1348,7 @@ func TestSandboxCreationFromConfigRollbackFromCreateSandbox(t *testing.T) { // Ensure hypervisor doesn't exist assert.NoError(os.Remove(hConf.HypervisorPath)) - _, err := createSandboxFromConfig(ctx, sConf, nil) + _, err := createSandboxFromConfig(ctx, sConf, nil, nil) // Fail at createSandbox: QEMU path does not exist, it is expected. Then rollback is called assert.Error(err)