runtime: remove global sandbox variable

Remove global sandbox variable, and save *Sandbox to hypervisor struct.
For some needs, hypervisor may need to use methods from Sandbox.

Signed-off-by: bin liu <bin@hyper.sh>
This commit is contained in:
bin liu 2020-11-04 16:18:09 +08:00
parent 290203943c
commit 4e3a8c0124
19 changed files with 59 additions and 50 deletions

View File

@ -312,17 +312,17 @@ func (s *service) Cleanup(ctx context.Context) (_ *taskAPI.DeleteResponse, err e
switch containerType { switch containerType {
case vc.PodSandbox: case vc.PodSandbox:
err = cleanupContainer(ctx, s.id, s.id, path) err = cleanupContainer(ctx, s.sandbox, s.id, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case vc.PodContainer: case vc.PodContainer:
sandboxID, err := oci.SandboxID(ociSpec) _, err := oci.SandboxID(ociSpec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = cleanupContainer(ctx, sandboxID, s.id, path) err = cleanupContainer(ctx, s.sandbox, s.id, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -31,10 +31,10 @@ func cReap(s *service, status int, id, execid string, exitat time.Time) {
} }
} }
func cleanupContainer(ctx context.Context, sid, cid, bundlePath string) error { func cleanupContainer(ctx context.Context, sandbox vc.VCSandbox, cid, bundlePath string) error {
shimLog.WithField("service", "cleanup").WithField("container", cid).Info("Cleanup container") shimLog.WithField("service", "cleanup").WithField("container", cid).Info("Cleanup container")
err := vci.CleanupContainer(ctx, sid, cid, true) err := vci.CleanupContainer(ctx, sandbox, cid, true)
if err != nil { if err != nil {
shimLog.WithError(err).WithField("container", cid).Warn("failed to cleanup container") shimLog.WithError(err).WithField("container", cid).Warn("failed to cleanup container")
return err return err

View File

@ -88,6 +88,7 @@ type Acrn struct {
arch acrnArch arch acrnArch
ctx context.Context ctx context.Context
store persistapi.PersistDriver store persistapi.PersistDriver
sandbox *Sandbox
} }
type acrnPlatformInfo struct { type acrnPlatformInfo struct {
@ -231,10 +232,9 @@ func (a *Acrn) appendImage(devices []Device, imagePath string) ([]Device, error)
// Get sandbox and increment the globalIndex. // Get sandbox and increment the globalIndex.
// This is to make sure the VM rootfs occupies // This is to make sure the VM rootfs occupies
// the first Index which is /dev/vda. // the first Index which is /dev/vda.
sandbox := globalSandbox
var err error var err error
if _, err = sandbox.GetAndSetSandboxBlockIndex(); err != nil { if _, err = a.sandbox.GetAndSetSandboxBlockIndex(); err != nil {
return nil, err return nil, err
} }
@ -821,3 +821,7 @@ func (a *Acrn) loadInfo() error {
func (a *Acrn) isRateLimiterBuiltin() bool { func (a *Acrn) isRateLimiterBuiltin() bool {
return false return false
} }
func (a *Acrn) setSandbox(sandbox *Sandbox) {
a.sandbox = sandbox
}

View File

@ -230,10 +230,7 @@ func TestAcrnCreateSandbox(t *testing.T) {
state: types.SandboxState{BlockIndexMap: make(map[int]struct{})}, state: types.SandboxState{BlockIndexMap: make(map[int]struct{})},
} }
globalSandbox = sandbox a.sandbox = sandbox
defer func() {
globalSandbox = nil
}()
//set PID to 1 to ignore hypercall to get UUID and set a random UUID //set PID to 1 to ignore hypercall to get UUID and set a random UUID
a.state.PID = 1 a.state.PID = 1

View File

@ -7,6 +7,7 @@ package virtcontainers
import ( import (
"context" "context"
"fmt"
"runtime" "runtime"
deviceApi "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/device/api" deviceApi "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/device/api"
@ -135,26 +136,24 @@ func createSandboxFromConfig(ctx context.Context, sandboxConfig SandboxConfig, f
// CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container
// in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by // in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by
// locking the sandbox. // locking the sandbox.
func CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { func CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error {
span, ctx := trace(ctx, "CleanupContainer") span, ctx := trace(ctx, "CleanupContainer")
defer span.Finish() defer span.Finish()
if sandboxID == "" {
return vcTypes.ErrNeedSandboxID
}
if containerID == "" { if containerID == "" {
return vcTypes.ErrNeedContainerID return vcTypes.ErrNeedContainerID
} }
unlock, err := rwLockSandbox(sandboxID) s, ok := sandbox.(*Sandbox)
if !ok {
return fmt.Errorf("not a Sandbox reference")
}
unlock, err := rwLockSandbox(s.id)
if err != nil { if err != nil {
return err return err
} }
defer unlock() defer unlock()
s := globalSandbox
defer s.Release() defer s.Release()
_, err = s.StopContainer(containerID, force) _, err = s.StopContainer(containerID, force)

View File

@ -307,7 +307,7 @@ func TestCleanupContainer(t *testing.T) {
} }
for _, c := range p.GetAllContainers() { for _, c := range p.GetAllContainers() {
CleanupContainer(ctx, p.ID(), c.ID(), true) CleanupContainer(ctx, p, c.ID(), true)
} }
s, ok := p.(*Sandbox) s, ok := p.(*Sandbox)

View File

@ -1297,3 +1297,6 @@ func (clh *cloudHypervisor) vmInfo() (chclient.VmInfo, error) {
func (clh *cloudHypervisor) isRateLimiterBuiltin() bool { func (clh *cloudHypervisor) isRateLimiterBuiltin() bool {
return false return false
} }
func (clh *cloudHypervisor) setSandbox(sandbox *Sandbox) {
}

View File

@ -1247,3 +1247,6 @@ func revertBytes(num uint64) uint64 {
return 1024*revertBytes(a) + b return 1024*revertBytes(a) + b
} }
} }
func (fc *firecracker) setSandbox(sandbox *Sandbox) {
}

View File

@ -814,4 +814,6 @@ type hypervisor interface {
// check if hypervisor supports built-in rate limiter. // check if hypervisor supports built-in rate limiter.
isRateLimiterBuiltin() bool isRateLimiterBuiltin() bool
setSandbox(sandbox *Sandbox)
} }

View File

@ -38,6 +38,6 @@ func (impl *VCImpl) CreateSandbox(ctx context.Context, sandboxConfig SandboxConf
// CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container // CleanupContainer is used by shimv2 to stop and delete a container exclusively, once there is no container
// in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by // in the sandbox left, do stop the sandbox and delete it. Those serial operations will be done exclusively by
// locking the sandbox. // locking the sandbox.
func (impl *VCImpl) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { func (impl *VCImpl) CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error {
return CleanupContainer(ctx, sandboxID, containerID, force) return CleanupContainer(ctx, sandbox, containerID, force)
} }

View File

@ -24,7 +24,7 @@ type VC interface {
SetFactory(ctx context.Context, factory Factory) SetFactory(ctx context.Context, factory Factory)
CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error) CreateSandbox(ctx context.Context, sandboxConfig SandboxConfig) (VCSandbox, error)
CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error CleanupContainer(ctx context.Context, sandbox VCSandbox, containerID string, force bool) error
} }
// VCSandbox is the Sandbox interface // VCSandbox is the Sandbox interface

View File

@ -136,3 +136,6 @@ func (m *mockHypervisor) generateSocket(id string) (interface{}, error) {
func (m *mockHypervisor) isRateLimiterBuiltin() bool { func (m *mockHypervisor) isRateLimiterBuiltin() bool {
return false return false
} }
func (m *mockHypervisor) setSandbox(sandbox *Sandbox) {
}

View File

@ -18,6 +18,7 @@ package vcmock
import ( import (
"context" "context"
"fmt" "fmt"
vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers" vc "github.com/kata-containers/kata-containers/src/runtime/virtcontainers"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -49,9 +50,9 @@ func (m *VCMock) CreateSandbox(ctx context.Context, sandboxConfig vc.SandboxConf
return nil, fmt.Errorf("%s: %s (%+v): sandboxConfig: %v", mockErrorPrefix, getSelf(), m, sandboxConfig) return nil, fmt.Errorf("%s: %s (%+v): sandboxConfig: %v", mockErrorPrefix, getSelf(), m, sandboxConfig)
} }
func (m *VCMock) CleanupContainer(ctx context.Context, sandboxID, containerID string, force bool) error { func (m *VCMock) CleanupContainer(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error {
if m.CleanupContainerFunc != nil { if m.CleanupContainerFunc != nil {
return m.CleanupContainerFunc(ctx, sandboxID, containerID, true) return m.CleanupContainerFunc(ctx, sandbox, containerID, true)
} }
return fmt.Errorf("%s: %s (%+v): sandboxID: %v", mockErrorPrefix, getSelf(), m, sandboxID) return fmt.Errorf("%s: %s (%+v): sandbox: %v", mockErrorPrefix, getSelf(), m, sandbox)
} }

View File

@ -17,13 +17,13 @@ import (
) )
const ( const (
testSandboxID = "testSandboxID"
testContainerID = "testContainerID" testContainerID = "testContainerID"
) )
var ( var (
loggerTriggered = 0 loggerTriggered = 0
factoryTriggered = 0 factoryTriggered = 0
testSandbox vc.VCSandbox = &Sandbox{}
) )
func TestVCImplementations(t *testing.T) { func TestVCImplementations(t *testing.T) {
@ -178,21 +178,21 @@ func TestVCMockCleanupContainer(t *testing.T) {
assert.Nil(m.CleanupContainerFunc) assert.Nil(m.CleanupContainerFunc)
ctx := context.Background() ctx := context.Background()
err := m.CleanupContainer(ctx, testSandboxID, testContainerID, false) err := m.CleanupContainer(ctx, testSandbox, testContainerID, false)
assert.Error(err) assert.Error(err)
assert.True(IsMockError(err)) assert.True(IsMockError(err))
m.CleanupContainerFunc = func(ctx context.Context, sandboxID, containerID string, force bool) error { m.CleanupContainerFunc = func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error {
return nil return nil
} }
err = m.CleanupContainer(ctx, testSandboxID, testContainerID, false) err = m.CleanupContainer(ctx, testSandbox, testContainerID, false)
assert.NoError(err) assert.NoError(err)
// reset // reset
m.CleanupContainerFunc = nil m.CleanupContainerFunc = nil
err = m.CleanupContainer(ctx, testSandboxID, testContainerID, false) err = m.CleanupContainer(ctx, testSandbox, testContainerID, false)
assert.Error(err) assert.Error(err)
assert.True(IsMockError(err)) assert.True(IsMockError(err))
} }
@ -204,21 +204,21 @@ func TestVCMockForceCleanupContainer(t *testing.T) {
assert.Nil(m.CleanupContainerFunc) assert.Nil(m.CleanupContainerFunc)
ctx := context.Background() ctx := context.Background()
err := m.CleanupContainer(ctx, testSandboxID, testContainerID, true) err := m.CleanupContainer(ctx, testSandbox, testContainerID, true)
assert.Error(err) assert.Error(err)
assert.True(IsMockError(err)) assert.True(IsMockError(err))
m.CleanupContainerFunc = func(ctx context.Context, sandboxID, containerID string, force bool) error { m.CleanupContainerFunc = func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error {
return nil return nil
} }
err = m.CleanupContainer(ctx, testSandboxID, testContainerID, true) err = m.CleanupContainer(ctx, testSandbox, testContainerID, true)
assert.NoError(err) assert.NoError(err)
// reset // reset
m.CleanupContainerFunc = nil m.CleanupContainerFunc = nil
err = m.CleanupContainer(ctx, testSandboxID, testContainerID, true) err = m.CleanupContainer(ctx, testSandbox, testContainerID, true)
assert.Error(err) assert.Error(err)
assert.True(IsMockError(err)) assert.True(IsMockError(err))
} }

View File

@ -88,5 +88,5 @@ type VCMock struct {
SetFactoryFunc func(ctx context.Context, factory vc.Factory) SetFactoryFunc func(ctx context.Context, factory vc.Factory)
CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error) CreateSandboxFunc func(ctx context.Context, sandboxConfig vc.SandboxConfig) (vc.VCSandbox, error)
CleanupContainerFunc func(ctx context.Context, sandboxID, containerID string, force bool) error CleanupContainerFunc func(ctx context.Context, sandbox vc.VCSandbox, containerID string, force bool) error
} }

View File

@ -2389,3 +2389,6 @@ func (q *qemu) generateSocket(id string) (interface{}, error) {
func (q *qemu) isRateLimiterBuiltin() bool { func (q *qemu) isRateLimiterBuiltin() bool {
return false return false
} }
func (q *qemu) setSandbox(sandbox *Sandbox) {
}

View File

@ -54,9 +54,6 @@ const (
DirMode = os.FileMode(0750) | os.ModeDir DirMode = os.FileMode(0750) | os.ModeDir
) )
// globalSandbox tracks sandbox globally
var globalSandbox *Sandbox
// SandboxStatus describes a sandbox status. // SandboxStatus describes a sandbox status.
type SandboxStatus struct { type SandboxStatus struct {
ID string ID string
@ -503,12 +500,12 @@ func newSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factor
ctx: ctx, ctx: ctx,
} }
hypervisor.setSandbox(s)
if s.newStore, err = persist.GetDriver(); err != nil || s.newStore == nil { if s.newStore, err = persist.GetDriver(); err != nil || s.newStore == nil {
return nil, fmt.Errorf("failed to get fs persist driver: %v", err) return nil, fmt.Errorf("failed to get fs persist driver: %v", err)
} }
globalSandbox = s
defer func() { defer func() {
if retErr != nil { if retErr != nil {
s.Logger().WithError(retErr).WithField("sandboxid", s.id).Error("Create new sandbox failed") s.Logger().WithError(retErr).WithField("sandboxid", s.id).Error("Create new sandbox failed")

View File

@ -301,13 +301,11 @@ func TestSandboxSetSandboxAndContainerState(t *testing.T) {
} }
// force state to be read from disk // force state to be read from disk
p2 := globalSandbox if err := testCheckSandboxOnDiskState(p, newSandboxState); err != nil {
if err := testCheckSandboxOnDiskState(p2, newSandboxState); err != nil {
t.Error(err) t.Error(err)
} }
c2, err := p2.findContainer(contID) c2, err := p.findContainer(contID)
assert.NoError(err) assert.NoError(err)
if err := testCheckContainerOnDiskState(c2, newContainerState); err != nil { if err := testCheckContainerOnDiskState(c2, newContainerState); err != nil {

View File

@ -58,7 +58,6 @@ var testHyperstartTtySocket = ""
// cleanUp Removes any stale sandbox/container state that can affect // cleanUp Removes any stale sandbox/container state that can affect
// the next test to run. // the next test to run.
func cleanUp() { func cleanUp() {
globalSandbox = nil
os.RemoveAll(fs.MockRunStoragePath()) os.RemoveAll(fs.MockRunStoragePath())
os.RemoveAll(fs.MockRunVMStoragePath()) os.RemoveAll(fs.MockRunVMStoragePath())
syscall.Unmount(getSharePath(testSandboxID), syscall.MNT_DETACH|UmountNoFollow) syscall.Unmount(getSharePath(testSandboxID), syscall.MNT_DETACH|UmountNoFollow)