diff --git a/virtcontainers/api.go b/virtcontainers/api.go index 7072926e25..3b86a03eb9 100644 --- a/virtcontainers/api.go +++ b/virtcontainers/api.go @@ -145,11 +145,11 @@ func DeleteSandbox(ctx context.Context, sandboxID string) (VCSandbox, error) { return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() // Fetch the sandbox from storage and create it. s, err := fetchSandbox(ctx, sandboxID) @@ -178,11 +178,11 @@ func FetchSandbox(ctx context.Context, sandboxID string) (VCSandbox, error) { return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() // Fetch the sandbox from storage and create it. s, err := fetchSandbox(ctx, sandboxID) @@ -215,11 +215,11 @@ func StartSandbox(ctx context.Context, sandboxID string) (VCSandbox, error) { return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() // Fetch the sandbox from storage and create it. s, err := fetchSandbox(ctx, sandboxID) @@ -251,11 +251,11 @@ func StopSandbox(ctx context.Context, sandboxID string, force bool) (VCSandbox, return nil, vcTypes.ErrNeedSandbox } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() // Fetch the sandbox from storage and create it. s, err := fetchSandbox(ctx, sandboxID) @@ -290,11 +290,11 @@ func RunSandbox(ctx context.Context, sandboxConfig SandboxConfig, factory Factor } defer s.releaseStatelessSandbox() - lockFile, err := rwLockSandbox(ctx, s.id) + unlock, err := rwLockSandbox(s.id) if err != nil { return nil, err } - defer unlockSandbox(ctx, s.id, lockFile) + defer unlock() // Start the sandbox err = s.Start() @@ -310,12 +310,7 @@ func ListSandbox(ctx context.Context) ([]SandboxStatus, error) { span, ctx := trace(ctx, "ListSandbox") defer span.Finish() - var sbsdir string - if supportNewStore(ctx) { - sbsdir = fs.RunStoragePath() - } else { - sbsdir = store.RunStoragePath() - } + sbsdir := fs.RunStoragePath() dir, err := os.Open(sbsdir) if err != nil { @@ -356,15 +351,14 @@ func StatusSandbox(ctx context.Context, sandboxID string) (SandboxStatus, error) return SandboxStatus{}, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return SandboxStatus{}, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { - unlockSandbox(ctx, sandboxID, lockFile) return SandboxStatus{}, err } defer s.releaseStatelessSandbox() @@ -402,11 +396,11 @@ func CreateContainer(ctx context.Context, sandboxID string, containerConfig Cont return nil, nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -441,11 +435,11 @@ func DeleteContainer(ctx context.Context, sandboxID, containerID string) (VCCont return nil, vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -470,11 +464,11 @@ func StartContainer(ctx context.Context, sandboxID, containerID string) (VCConta return nil, vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -499,11 +493,11 @@ func StopContainer(ctx context.Context, sandboxID, containerID string) (VCContai return nil, vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -528,11 +522,11 @@ func EnterContainer(ctx context.Context, sandboxID, containerID string, cmd type return nil, nil, nil, vcTypes.ErrNeedContainerID } - lockFile, err := rLockSandbox(ctx, sandboxID) + unlock, err := rLockSandbox(sandboxID) if err != nil { return nil, nil, nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -562,15 +556,14 @@ func StatusContainer(ctx context.Context, sandboxID, containerID string) (Contai return ContainerStatus{}, vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return ContainerStatus{}, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { - unlockSandbox(ctx, sandboxID, lockFile) return ContainerStatus{}, err } defer s.releaseStatelessSandbox() @@ -646,11 +639,11 @@ func KillContainer(ctx context.Context, sandboxID, containerID string, signal sy return vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -675,11 +668,11 @@ func ProcessListContainer(ctx context.Context, sandboxID, containerID string, op return nil, vcTypes.ErrNeedContainerID } - lockFile, err := rLockSandbox(ctx, sandboxID) + unlock, err := rLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -704,11 +697,11 @@ func UpdateContainer(ctx context.Context, sandboxID, containerID string, resourc return vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -732,12 +725,12 @@ func StatsContainer(ctx context.Context, sandboxID, containerID string) (Contain if containerID == "" { return ContainerStats{}, vcTypes.ErrNeedContainerID } - lockFile, err := rLockSandbox(ctx, sandboxID) + + unlock, err := rLockSandbox(sandboxID) if err != nil { return ContainerStats{}, err } - - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -758,12 +751,11 @@ func StatsSandbox(ctx context.Context, sandboxID string) (SandboxStats, []Contai return SandboxStats{}, []ContainerStats{}, vcTypes.ErrNeedSandboxID } - lockFile, err := rLockSandbox(ctx, sandboxID) + unlock, err := rLockSandbox(sandboxID) if err != nil { return SandboxStats{}, []ContainerStats{}, err } - - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -797,11 +789,11 @@ func togglePauseContainer(ctx context.Context, sandboxID, containerID string, pa return vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -841,11 +833,11 @@ func AddDevice(ctx context.Context, sandboxID string, info deviceConfig.DeviceIn return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -861,11 +853,11 @@ func toggleInterface(ctx context.Context, sandboxID string, inf *vcTypes.Interfa return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -905,11 +897,11 @@ func ListInterfaces(ctx context.Context, sandboxID string) ([]*vcTypes.Interface return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rLockSandbox(ctx, sandboxID) + unlock, err := rLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -929,11 +921,11 @@ func UpdateRoutes(ctx context.Context, sandboxID string, routes []*vcTypes.Route return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -953,11 +945,11 @@ func ListRoutes(ctx context.Context, sandboxID string) ([]*vcTypes.Route, error) return nil, vcTypes.ErrNeedSandboxID } - lockFile, err := rLockSandbox(ctx, sandboxID) + unlock, err := rLockSandbox(sandboxID) if err != nil { return nil, err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { @@ -983,11 +975,11 @@ func CleanupContainer(ctx context.Context, sandboxID, containerID string, force return vcTypes.ErrNeedContainerID } - lockFile, err := rwLockSandbox(ctx, sandboxID) + unlock, err := rwLockSandbox(sandboxID) if err != nil { return err } - defer unlockSandbox(ctx, sandboxID, lockFile) + defer unlock() s, err := fetchSandbox(ctx, sandboxID) if err != nil { diff --git a/virtcontainers/persist/fs/fs.go b/virtcontainers/persist/fs/fs.go index ac8d1c48b0..d5ac624ecc 100644 --- a/virtcontainers/persist/fs/fs.go +++ b/virtcontainers/persist/fs/fs.go @@ -264,9 +264,9 @@ func (fs *FS) Lock(sandboxID string, exclusive bool) (func() error, error) { var lockType int if exclusive { - lockType = syscall.LOCK_EX | syscall.LOCK_NB + lockType = syscall.LOCK_EX } else { - lockType = syscall.LOCK_SH | syscall.LOCK_NB + lockType = syscall.LOCK_SH } if err := syscall.Flock(int(f.Fd()), lockType); err != nil { diff --git a/virtcontainers/persist/fs/fs_test.go b/virtcontainers/persist/fs/fs_test.go index 75bb4d879f..9fe889674e 100644 --- a/virtcontainers/persist/fs/fs_test.go +++ b/virtcontainers/persist/fs/fs_test.go @@ -28,7 +28,7 @@ func getFsDriver() (*FS, error) { return fs, nil } -func TestFsLock(t *testing.T) { +func TestFsLockShared(t *testing.T) { fs, err := getFsDriver() assert.Nil(t, err) assert.NotNil(t, fs) @@ -48,17 +48,42 @@ func TestFsLock(t *testing.T) { err = os.MkdirAll(sandboxDir, dirMode) assert.Nil(t, err) + // Take 2 shared locks unlockFunc, err := fs.Lock(sid, false) assert.Nil(t, err) + unlockFunc2, err := fs.Lock(sid, false) assert.Nil(t, err) - _, err = fs.Lock(sid, true) - assert.NotNil(t, err) assert.Nil(t, unlockFunc()) - // double unlock should return error - assert.NotNil(t, unlockFunc()) assert.Nil(t, unlockFunc2()) + assert.NotNil(t, unlockFunc2()) +} + +func TestFsLockExclusive(t *testing.T) { + fs, err := getFsDriver() + assert.Nil(t, err) + assert.NotNil(t, fs) + + sid := "test-fs-driver" + fs.sandboxState.SandboxContainer = sid + sandboxDir, err := fs.sandboxDir(sid) + assert.Nil(t, err) + + err = os.MkdirAll(sandboxDir, dirMode) + assert.Nil(t, err) + + // Take 1 exclusive lock + unlockFunc, err := fs.Lock(sid, true) + assert.Nil(t, err) + + assert.Nil(t, unlockFunc()) + + unlockFunc, err = fs.Lock(sid, true) + assert.Nil(t, err) + + assert.Nil(t, unlockFunc()) + assert.NotNil(t, unlockFunc()) } func TestFsDriver(t *testing.T) { diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index dea3cfe654..570dae6f97 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -604,38 +604,22 @@ func (s *Sandbox) storeSandbox() error { return nil } -func rLockSandbox(ctx context.Context, sandboxID string) (string, error) { - store, err := store.NewVCSandboxStore(ctx, sandboxID) +func rLockSandbox(sandboxID string) (func() error, error) { + store, err := persist.GetDriver("fs") if err != nil { - return "", err + return nil, fmt.Errorf("failed to get fs persist driver: %v", err) } - return store.RLock() + return store.Lock(sandboxID, false) } -func rwLockSandbox(ctx context.Context, sandboxID string) (string, error) { - store, err := store.NewVCSandboxStore(ctx, sandboxID) +func rwLockSandbox(sandboxID string) (func() error, error) { + store, err := persist.GetDriver("fs") if err != nil { - return "", err + return nil, fmt.Errorf("failed to get fs persist driver: %v", err) } - return store.Lock() -} - -func unlockSandbox(ctx context.Context, sandboxID, token string) error { - // If the store no longer exists, we won't be able to unlock. - // Creating a new store for locking an item that does not even exist - // does not make sense. - if !store.VCSandboxStoreExists(ctx, sandboxID) { - return nil - } - - store, err := store.NewVCSandboxStore(ctx, sandboxID) - if err != nil { - return err - } - - return store.Unlock(token) + return store.Lock(sandboxID, true) } func supportNewStore(ctx context.Context) bool {