diff --git a/src/runtime/containerd-shim-v2/service.go b/src/runtime/containerd-shim-v2/service.go index 795f27e4b5..2903824779 100644 --- a/src/runtime/containerd-shim-v2/service.go +++ b/src/runtime/containerd-shim-v2/service.go @@ -312,17 +312,17 @@ func (s *service) Cleanup(ctx context.Context) (_ *taskAPI.DeleteResponse, err e switch containerType { case vc.PodSandbox: - err = cleanupContainer(ctx, s.id, s.id, path) + err = cleanupContainer(ctx, s.sandbox, s.id, path) if err != nil { return nil, err } case vc.PodContainer: - sandboxID, err := oci.SandboxID(ociSpec) + _, err := oci.SandboxID(ociSpec) if err != nil { return nil, err } - err = cleanupContainer(ctx, sandboxID, s.id, path) + err = cleanupContainer(ctx, s.sandbox, s.id, path) if err != nil { return nil, err } diff --git a/src/runtime/containerd-shim-v2/utils.go b/src/runtime/containerd-shim-v2/utils.go index 95adb93faa..3172483500 100644 --- a/src/runtime/containerd-shim-v2/utils.go +++ b/src/runtime/containerd-shim-v2/utils.go @@ -31,10 +31,10 @@ func cReap(s *service, status int, id, execid string, exitat time.Time) { } } -func cleanupContainer(ctx context.Context, sid, cid, bundlePath string) error { +func cleanupContainer(ctx context.Context, sandbox vc.VCSandbox, cid, bundlePath string) error { shimLog.WithField("service", "cleanup").WithField("container", cid).Info("Cleanup container") - err := vci.CleanupContainer(ctx, sid, cid, true) + err := vci.CleanupContainer(ctx, sandbox, cid, true) if err != nil { shimLog.WithError(err).WithField("container", cid).Warn("failed to cleanup container") return err diff --git a/src/runtime/virtcontainers/acrn.go b/src/runtime/virtcontainers/acrn.go index 5b6010d211..b37f65bb06 100644 --- a/src/runtime/virtcontainers/acrn.go +++ b/src/runtime/virtcontainers/acrn.go @@ -88,6 +88,7 @@ type Acrn struct { arch acrnArch ctx context.Context store persistapi.PersistDriver + sandbox *Sandbox } type acrnPlatformInfo struct { @@ -231,10 +232,9 @@ func (a *Acrn) appendImage(devices []Device, imagePath string) ([]Device, error) // Get sandbox and increment the globalIndex. // This is to make sure the VM rootfs occupies // the first Index which is /dev/vda. - sandbox := globalSandbox var err error - if _, err = sandbox.GetAndSetSandboxBlockIndex(); err != nil { + if _, err = a.sandbox.GetAndSetSandboxBlockIndex(); err != nil { return nil, err } @@ -821,3 +821,7 @@ func (a *Acrn) loadInfo() error { func (a *Acrn) isRateLimiterBuiltin() bool { return false } + +func (a *Acrn) setSandbox(sandbox *Sandbox) { + a.sandbox = sandbox +} diff --git a/src/runtime/virtcontainers/acrn_test.go b/src/runtime/virtcontainers/acrn_test.go index 27dbd088e7..eb357b0034 100644 --- a/src/runtime/virtcontainers/acrn_test.go +++ b/src/runtime/virtcontainers/acrn_test.go @@ -230,10 +230,7 @@ func TestAcrnCreateSandbox(t *testing.T) { state: types.SandboxState{BlockIndexMap: make(map[int]struct{})}, } - globalSandbox = sandbox - defer func() { - globalSandbox = nil - }() + a.sandbox = sandbox //set PID to 1 to ignore hypercall to get UUID and set a random UUID a.state.PID = 1 diff --git a/src/runtime/virtcontainers/api.go b/src/runtime/virtcontainers/api.go index 0335059f4d..5c3f402108 100644 --- a/src/runtime/virtcontainers/api.go +++ b/src/runtime/virtcontainers/api.go @@ -7,6 +7,7 @@ package virtcontainers import ( "context" + "fmt" "runtime" deviceApi "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/device/api" @@ -135,26 +136,24 @@ func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, f // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container // in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by // locking the sandbox. -func CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { +func CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error { span, ctx := trace(ctx, "CleanupContainer") defer span.Finish() - - if sandboxID == "" { - return vcTypes.ErrNeedSandboxID - } - if containerID == "" { return vcTypes.ErrNeedContainerID } - unlock, err := rwLockSandbox(sandboxID) + s, ok := sandbox.(*Sandbox) + if !ok { + return fmt.Errorf("not a Sandbox reference") + } + + unlock, err := rwLockSandbox(s.id) if err != nil { return err } defer unlock() - s := globalSandbox - defer s.Release() _, err = s.StopContainer(containerID, force) diff --git a/src/runtime/virtcontainers/api_test.go b/src/runtime/virtcontainers/api_test.go index 927e450c8d..f0447b3ffb 100644 --- a/src/runtime/virtcontainers/api_test.go +++ b/src/runtime/virtcontainers/api_test.go @@ -307,7 +307,7 @@ func TestCleanupContainer(t *testing.T) { } for _, c := range p.GetAllContainers() { - CleanupContainer(ctx, p.ID(), c.ID(), true) + CleanupContainer(ctx, p, c.ID(), true) } s, ok := p.(*Sandbox) diff --git a/src/runtime/virtcontainers/clh.go b/src/runtime/virtcontainers/clh.go index 30aedd4c02..6125d16795 100644 --- a/src/runtime/virtcontainers/clh.go +++ b/src/runtime/virtcontainers/clh.go @@ -1297,3 +1297,6 @@ func (clh *cloudHypervisor) vmInfo() (chclient.VmInfo, error) { func (clh *cloudHypervisor) isRateLimiterBuiltin() bool { return false } + +func (clh *cloudHypervisor) setSandbox(sandbox *Sandbox) { +} diff --git a/src/runtime/virtcontainers/fc.go b/src/runtime/virtcontainers/fc.go index 37fdbdc105..8450006904 100644 --- a/src/runtime/virtcontainers/fc.go +++ b/src/runtime/virtcontainers/fc.go @@ -1247,3 +1247,6 @@ func revertBytes(num uint64) uint64 { return 1024*revertBytes(a) + b } } + +func (fc *firecracker) setSandbox(sandbox *Sandbox) { +} diff --git a/src/runtime/virtcontainers/hypervisor.go b/src/runtime/virtcontainers/hypervisor.go index ee7a2b6b73..75834e8da7 100644 --- a/src/runtime/virtcontainers/hypervisor.go +++ b/src/runtime/virtcontainers/hypervisor.go @@ -814,4 +814,6 @@ type hypervisor interface { // check if hypervisor supports built-in rate limiter. isRateLimiterBuiltin() bool + + setSandbox(sandbox *Sandbox) } diff --git a/src/runtime/virtcontainers/implementation.go b/src/runtime/virtcontainers/implementation.go index 177797ebd2..186b59e87a 100644 --- a/src/runtime/virtcontainers/implementation.go +++ b/src/runtime/virtcontainers/implementation.go @@ -38,6 +38,6 @@ func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConf // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container // in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by // locking the sandbox. -func (impl *VCImpl) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { - return CleanupContainer(ctx, sandboxID, containerID, force) +func (impl *VCImpl) CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error { + return CleanupContainer(ctx, sandbox, containerID, force) } diff --git a/src/runtime/virtcontainers/interfaces.go b/src/runtime/virtcontainers/interfaces.go index 365c329db0..8aeed1498f 100644 --- a/src/runtime/virtcontainers/interfaces.go +++ b/src/runtime/virtcontainers/interfaces.go @@ -24,7 +24,7 @@ type VC interface { SetFactory(ctx context.Context, factory Factory) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) - CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error + CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error } // VCSandbox is the Sandbox interface diff --git a/src/runtime/virtcontainers/mock_hypervisor.go b/src/runtime/virtcontainers/mock_hypervisor.go index 5dff86fcbd..d7a2480c2b 100644 --- a/src/runtime/virtcontainers/mock_hypervisor.go +++ b/src/runtime/virtcontainers/mock_hypervisor.go @@ -136,3 +136,6 @@ func (m *mockHypervisor) generateSocket(id string) (interface{}, error) { func (m *mockHypervisor) isRateLimiterBuiltin() bool { return false } + +func (m *mockHypervisor) setSandbox(sandbox *Sandbox) { +} diff --git a/src/runtime/virtcontainers/pkg/vcmock/mock.go b/src/runtime/virtcontainers/pkg/vcmock/mock.go index 5a13c76c45..0dd2b579f6 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/mock.go +++ b/src/runtime/virtcontainers/pkg/vcmock/mock.go @@ -18,6 +18,7 @@ package vcmock import ( "context" "fmt" + vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" "github.com/sirupsen/logrus" ) @@ -49,9 +50,9 @@ func (m *VCMock) CreateSandbox(ctx context.Context, sandboxConfig vc.SandboxConf return nil, fmt.Errorf("%s: %s (%+v): sandboxConfig: %v", mockErrorPrefix, getSelf(), m, sandboxConfig) } -func (m *VCMock) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { +func (m *VCMock) CleanupContainer(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error { if m.CleanupContainerFunc != nil { - return m.CleanupContainerFunc(ctx, sandboxID, containerID, true) + return m.CleanupContainerFunc(ctx, sandbox, containerID, true) } - return fmt.Errorf("%s: %s (%+v): sandboxID: %v", mockErrorPrefix, getSelf(), m, sandboxID) + return fmt.Errorf("%s: %s (%+v): sandbox: %v", mockErrorPrefix, getSelf(), m, sandbox) } diff --git a/src/runtime/virtcontainers/pkg/vcmock/mock_test.go b/src/runtime/virtcontainers/pkg/vcmock/mock_test.go index 9043b168da..79c543b3cd 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/mock_test.go +++ b/src/runtime/virtcontainers/pkg/vcmock/mock_test.go @@ -17,13 +17,13 @@ import ( ) const ( - testSandboxID = "testSandboxID" testContainerID = "testContainerID" ) var ( - loggerTriggered = 0 - factoryTriggered = 0 + loggerTriggered = 0 + factoryTriggered = 0 + testSandbox vc.VCSandbox = &Sandbox{} ) func TestVCImplementations(t *testing.T) { @@ -178,21 +178,21 @@ func TestVCMockCleanupContainer(t *testing.T) { assert.Nil(m.CleanupContainerFunc) ctx := context.Background() - err := m.CleanupContainer(ctx, testSandboxID, testContainerID, false) + err := m.CleanupContainer(ctx, testSandbox, testContainerID, false) assert.Error(err) assert.True(IsMockError(err)) - m.CleanupContainerFunc = func(ctx context.Context, sandboxID, containerID string, force bool) error { + m.CleanupContainerFunc = func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error { return nil } - err = m.CleanupContainer(ctx, testSandboxID, testContainerID, false) + err = m.CleanupContainer(ctx, testSandbox, testContainerID, false) assert.NoError(err) // reset m.CleanupContainerFunc = nil - err = m.CleanupContainer(ctx, testSandboxID, testContainerID, false) + err = m.CleanupContainer(ctx, testSandbox, testContainerID, false) assert.Error(err) assert.True(IsMockError(err)) } @@ -204,21 +204,21 @@ func TestVCMockForceCleanupContainer(t *testing.T) { assert.Nil(m.CleanupContainerFunc) ctx := context.Background() - err := m.CleanupContainer(ctx, testSandboxID, testContainerID, true) + err := m.CleanupContainer(ctx, testSandbox, testContainerID, true) assert.Error(err) assert.True(IsMockError(err)) - m.CleanupContainerFunc = func(ctx context.Context, sandboxID, containerID string, force bool) error { + m.CleanupContainerFunc = func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error { return nil } - err = m.CleanupContainer(ctx, testSandboxID, testContainerID, true) + err = m.CleanupContainer(ctx, testSandbox, testContainerID, true) assert.NoError(err) // reset m.CleanupContainerFunc = nil - err = m.CleanupContainer(ctx, testSandboxID, testContainerID, true) + err = m.CleanupContainer(ctx, testSandbox, testContainerID, true) assert.Error(err) assert.True(IsMockError(err)) } diff --git a/src/runtime/virtcontainers/pkg/vcmock/types.go b/src/runtime/virtcontainers/pkg/vcmock/types.go index 9caa53232a..de7f034421 100644 --- a/src/runtime/virtcontainers/pkg/vcmock/types.go +++ b/src/runtime/virtcontainers/pkg/vcmock/types.go @@ -87,6 +87,6 @@ type VCMock struct { SetLoggerFunc func(ctx context.Context, logger *logrus.Entry) SetFactoryFunc func(ctx context.Context, factory vc.Factory) - CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) - CleanupContainerFunc func(ctx context.Context, sandboxID, containerID string, force bool) error + CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) + CleanupContainerFunc func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error } diff --git a/src/runtime/virtcontainers/qemu.go b/src/runtime/virtcontainers/qemu.go index 932139aa01..9d5366611b 100644 --- a/src/runtime/virtcontainers/qemu.go +++ b/src/runtime/virtcontainers/qemu.go @@ -2389,3 +2389,6 @@ func (q *qemu) generateSocket(id string) (interface{}, error) { func (q *qemu) isRateLimiterBuiltin() bool { return false } + +func (q *qemu) setSandbox(sandbox *Sandbox) { +} diff --git a/src/runtime/virtcontainers/sandbox.go b/src/runtime/virtcontainers/sandbox.go index 811bed0205..118ece60e0 100644 --- a/src/runtime/virtcontainers/sandbox.go +++ b/src/runtime/virtcontainers/sandbox.go @@ -54,9 +54,6 @@ const ( DirMode = os.FileMode(0750) | os.ModeDir ) -// globalSandbox tracks sandbox globally -var globalSandbox *Sandbox - // SandboxStatus describes a sandbox status. type SandboxStatus struct { ID string @@ -503,12 +500,12 @@ func newSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factor ctx: ctx, } + hypervisor.setSandbox(s) + if s.newStore, err = persist.GetDriver(); err != nil || s.newStore == nil { return nil, fmt.Errorf("failed to get fs persist driver: %v", err) } - globalSandbox = s - defer func() { if retErr != nil { s.Logger().WithError(retErr).WithField("sandboxid", s.id).Error("Create new sandbox failed") diff --git a/src/runtime/virtcontainers/sandbox_test.go b/src/runtime/virtcontainers/sandbox_test.go index 110cd8e9b3..24458a2117 100644 --- a/src/runtime/virtcontainers/sandbox_test.go +++ b/src/runtime/virtcontainers/sandbox_test.go @@ -301,13 +301,11 @@ func TestSandboxSetSandboxAndContainerState(t *testing.T) { } // force state to be read from disk - p2 := globalSandbox - - if err := testCheckSandboxOnDiskState(p2, newSandboxState); err != nil { + if err := testCheckSandboxOnDiskState(p, newSandboxState); err != nil { t.Error(err) } - c2, err := p2.findContainer(contID) + c2, err := p.findContainer(contID) assert.NoError(err) if err := testCheckContainerOnDiskState(c2, newContainerState); err != nil { diff --git a/src/runtime/virtcontainers/virtcontainers_test.go b/src/runtime/virtcontainers/virtcontainers_test.go index 6960ffc106..cbec355eaf 100644 --- a/src/runtime/virtcontainers/virtcontainers_test.go +++ b/src/runtime/virtcontainers/virtcontainers_test.go @@ -58,7 +58,6 @@ var testHyperstartTtySocket = "" // cleanUp Removes any stale sandbox/container state that can affect // the next test to run. func cleanUp() { - globalSandbox = nil os.RemoveAll(fs.MockRunStoragePath()) os.RemoveAll(fs.MockRunVMStoragePath()) syscall.Unmount(getSharePath(testSandboxID), syscall.MNT_DETACH|UmountNoFollow)