diff --git a/src/runtime/virtcontainers/clh.go b/src/runtime/virtcontainers/clh.go index 6295031cdf..ed9be694d0 100644 --- a/src/runtime/virtcontainers/clh.go +++ b/src/runtime/virtcontainers/clh.go @@ -362,7 +362,9 @@ func (clh *cloudHypervisor) startSandbox(ctx context.Context, timeout int) error if clh.config.SharedFS == config.VirtioFS { clh.Logger().WithField("function", "startSandbox").Info("Starting virtiofsd") - pid, err := clh.virtiofsd.Start(ctx) + pid, err := clh.virtiofsd.Start(ctx, func() { + clh.stopSandbox(ctx, false) + }) if err != nil { return err } diff --git a/src/runtime/virtcontainers/qemu.go b/src/runtime/virtcontainers/qemu.go index cf1c4377ee..74ec8a8363 100644 --- a/src/runtime/virtcontainers/qemu.go +++ b/src/runtime/virtcontainers/qemu.go @@ -13,9 +13,7 @@ import ( "fmt" "io/ioutil" "math" - "net" "os" - "os/exec" "path/filepath" "strconv" "strings" @@ -106,6 +104,8 @@ type qemu struct { // if in memory dump progress memoryDumpFlag sync.Mutex + + virtiofsd Virtiofsd } const ( @@ -632,6 +632,20 @@ func (q *qemu) createSandbox(ctx context.Context, id string, networkNS NetworkNa q.qemuConfig = qemuConfig + virtiofsdSocketPath, err := q.vhostFSSocketPath(q.id) + if err != nil { + return err + } + + q.virtiofsd = &virtiofsd{ + path: q.config.VirtioFSDaemon, + sourcePath: filepath.Join(getSharePath(q.id)), + socketPath: virtiofsdSocketPath, + extraArgs: q.config.VirtioFSExtraArgs, + debug: q.config.Debug, + cache: q.config.VirtioFSCache, + } + return nil } @@ -639,102 +653,29 @@ func (q *qemu) vhostFSSocketPath(id string) (string, error) { return utils.BuildSocketPath(q.store.RunVMStoragePath(), id, vhostFSSocket) } -func (q *qemu) virtiofsdArgs(fd uintptr) []string { - // The daemon will terminate when the vhost-user socket - // connection with QEMU closes. Therefore we do not keep track - // of this child process after returning from this function. - sourcePath := filepath.Join(getSharePath(q.id)) - args := []string{ - fmt.Sprintf("--fd=%v", fd), - "-o", "source=" + sourcePath, - "-o", "cache=" + q.config.VirtioFSCache, - "--syslog", "-o", "no_posix_lock"} - if q.config.Debug { - args = append(args, "-d") - } else { - args = append(args, "-f") - } - - if len(q.config.VirtioFSExtraArgs) != 0 { - args = append(args, q.config.VirtioFSExtraArgs...) - } - return args -} - func (q *qemu) setupVirtiofsd(ctx context.Context) (err error) { - var listener *net.UnixListener - var fd *os.File - - sockPath, err := q.vhostFSSocketPath(q.id) - if err != nil { - return err - } - - listener, err = net.ListenUnix("unix", &net.UnixAddr{ - Name: sockPath, - Net: "unix", + pid, err := q.virtiofsd.Start(ctx, func() { + q.stopSandbox(ctx, false) }) if err != nil { return err } - listener.SetUnlinkOnClose(false) + q.state.VirtiofsdPid = pid - fd, err = listener.File() - listener.Close() // no longer needed since fd is a dup - listener = nil - if err != nil { - return err - } - defer fd.Close() - - const sockFd = 3 // Cmd.ExtraFiles[] fds are numbered starting from 3 - cmd := exec.Command(q.config.VirtioFSDaemon, q.virtiofsdArgs(sockFd)...) - cmd.ExtraFiles = append(cmd.ExtraFiles, fd) - stderr, err := cmd.StderrPipe() - if err != nil { - return err - } - - err = cmd.Start() - if err != nil { - return fmt.Errorf("virtiofs daemon %v returned with error: %v", q.config.VirtioFSDaemon, err) - } - q.state.VirtiofsdPid = cmd.Process.Pid - - // Monitor virtiofsd's stderr and stop sandbox if virtiofsd quits - go func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - q.Logger().WithField("source", "virtiofsd").Info(scanner.Text()) - } - q.Logger().Info("virtiofsd quits") - // Wait to release resources of virtiofsd process - cmd.Process.Wait() - q.stopSandbox(ctx, false) - }() - return err + return nil } func (q *qemu) stopVirtiofsd(ctx context.Context) (err error) { - - // kill virtiofsd if q.state.VirtiofsdPid == 0 { return errors.New("invalid virtiofsd PID(0)") } - err = syscall.Kill(q.state.VirtiofsdPid, syscall.SIGKILL) + err = q.virtiofsd.Stop(ctx) if err != nil { return err } q.state.VirtiofsdPid = 0 - - // remove virtiofsd socket - sockPath, err := q.vhostFSSocketPath(q.id) - if err != nil { - return err - } - - return os.Remove(sockPath) + return nil } func (q *qemu) getMemArgs() (bool, string, string, error) { diff --git a/src/runtime/virtcontainers/qemu_test.go b/src/runtime/virtcontainers/qemu_test.go index 3e8d0b95e5..645842c11d 100644 --- a/src/runtime/virtcontainers/qemu_test.go +++ b/src/runtime/virtcontainers/qemu_test.go @@ -11,7 +11,6 @@ import ( "io/ioutil" "os" "path/filepath" - "strings" "testing" govmmQemu "github.com/kata-containers/govmm/qemu" @@ -550,35 +549,6 @@ func createQemuSandboxConfig() (*Sandbox, error) { return &sandbox, nil } -func TestQemuVirtiofsdArgs(t *testing.T) { - assert := assert.New(t) - - q := &qemu{ - id: "foo", - config: HypervisorConfig{ - VirtioFSCache: "none", - Debug: true, - }, - } - - savedKataHostSharedDir := kataHostSharedDir - kataHostSharedDir = func() string { - return "test-share-dir" - } - defer func() { - kataHostSharedDir = savedKataHostSharedDir - }() - - result := "--fd=123 -o source=test-share-dir/foo/shared -o cache=none --syslog -o no_posix_lock -d" - args := q.virtiofsdArgs(123) - assert.Equal(strings.Join(args, " "), result) - - q.config.Debug = false - result = "--fd=123 -o source=test-share-dir/foo/shared -o cache=none --syslog -o no_posix_lock -f" - args = q.virtiofsdArgs(123) - assert.Equal(strings.Join(args, " "), result) -} - func TestQemuGetpids(t *testing.T) { assert := assert.New(t) diff --git a/src/runtime/virtcontainers/virtiofsd.go b/src/runtime/virtcontainers/virtiofsd.go index be2c0e069c..bd806782f2 100644 --- a/src/runtime/virtcontainers/virtiofsd.go +++ b/src/runtime/virtcontainers/virtiofsd.go @@ -9,14 +9,12 @@ import ( "bufio" "context" "fmt" - "io" "net" "os" "os/exec" "path/filepath" "strings" "syscall" - "time" "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils" "github.com/pkg/errors" @@ -26,20 +24,22 @@ import ( otelTrace "go.opentelemetry.io/otel/trace" ) -const ( - //Timeout to wait in secounds - virtiofsdStartTimeout = 5 +var ( + errVirtiofsdDaemonPathEmpty = errors.New("virtiofsd daemon path is empty") + errVirtiofsdSocketPathEmpty = errors.New("virtiofsd socket path is empty") + errVirtiofsdSourcePathEmpty = errors.New("virtiofsd source path is empty") + errVirtiofsdSourceNotAvailable = errors.New("virtiofsd source path not available") ) type Virtiofsd interface { // Start virtiofsd, return pid of virtiofsd process - Start(context.Context) (pid int, err error) + Start(context.Context, onQuitFunc) (pid int, err error) // Stop virtiofsd process Stop(context.Context) error } -// Helper function to check virtiofsd is serving -type virtiofsdWaitFunc func(runningCmd *exec.Cmd, stderr io.ReadCloser, debug bool) error +// Helper function to execute when virtiofsd quit +type onQuitFunc func() type virtiofsd struct { // path to virtiofsd daemon @@ -58,8 +58,6 @@ type virtiofsd struct { PID int // Neded by tracing ctx context.Context - // wait helper function to check if virtiofsd is serving - wait virtiofsdWaitFunc } // Open socket on behalf of virtiofsd @@ -85,7 +83,7 @@ func (v *virtiofsd) getSocketFD() (*os.File, error) { } // Start the virtiofsd daemon -func (v *virtiofsd) Start(ctx context.Context) (int, error) { +func (v *virtiofsd) Start(ctx context.Context, onQuit onQuitFunc) (int, error) { span, _ := v.trace(ctx, "Start") defer span.End() pid := 0 @@ -116,21 +114,29 @@ func (v *virtiofsd) Start(ctx context.Context) (int, error) { v.Logger().WithField("path", v.path).Info() v.Logger().WithField("args", strings.Join(args, " ")).Info() + stderr, err := cmd.StderrPipe() + if err != nil { + return pid, err + } if err = utils.StartCmd(cmd); err != nil { return pid, err } - defer func() { - if err != nil { - cmd.Process.Kill() + // Monitor virtiofsd's stderr and stop sandbox if virtiofsd quits + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + v.Logger().WithField("source", "virtiofsd").Info(scanner.Text()) + } + v.Logger().Info("virtiofsd quits") + // Wait to release resources of virtiofsd process + cmd.Process.Wait() + if onQuit != nil { + onQuit() } }() - if v.wait == nil { - v.wait = waitVirtiofsReady - } - return cmd.Process.Pid, nil } @@ -139,10 +145,6 @@ func (v *virtiofsd) Stop(ctx context.Context) error { return nil } - if v.socketPath == "" { - return errors.New("vitiofsd socket path is empty") - } - err := os.Remove(v.socketPath) if err != nil { v.Logger().WithError(err).WithField("path", v.socketPath).Warn("removing virtiofsd socket failed") @@ -151,19 +153,10 @@ func (v *virtiofsd) Stop(ctx context.Context) error { } func (v *virtiofsd) args(FdSocketNumber uint) ([]string, error) { - if v.sourcePath == "" { - return []string{}, errors.New("vitiofsd source path is empty") - } - - if _, err := os.Stat(v.sourcePath); os.IsNotExist(err) { - return nil, err - } args := []string{ // Send logs to syslog "--syslog", - // foreground operation - "-f", // cache mode for virtiofsd "-o", "cache=" + v.cache, // disable posix locking in daemon: bunch of basic posix locks properties are broken @@ -176,7 +169,11 @@ func (v *virtiofsd) args(FdSocketNumber uint) ([]string, error) { } if v.debug { - args = append(args, "-o", "debug") + // enable debug output (implies -f) + args = append(args, "-d") + } else { + // foreground operation + args = append(args, "-f") } if len(v.extraArgs) != 0 { @@ -188,18 +185,20 @@ func (v *virtiofsd) args(FdSocketNumber uint) ([]string, error) { func (v *virtiofsd) valid() error { if v.path == "" { - errors.New("virtiofsd path is empty") + return errVirtiofsdDaemonPathEmpty } if v.socketPath == "" { - errors.New("Virtiofsd socket path is empty") + return errVirtiofsdSocketPathEmpty } if v.sourcePath == "" { - errors.New("virtiofsd source path is empty") - + return errVirtiofsdSourcePathEmpty } + if _, err := os.Stat(v.sourcePath); err != nil { + return errVirtiofsdSourceNotAvailable + } return nil } @@ -219,49 +218,6 @@ func (v *virtiofsd) trace(parent context.Context, name string) (otelTrace.Span, return span, ctx } -func waitVirtiofsReady(cmd *exec.Cmd, stderr io.ReadCloser, debug bool) error { - if cmd == nil { - return errors.New("cmd is nil") - } - - sockReady := make(chan error, 1) - go func() { - scanner := bufio.NewScanner(stderr) - var sent bool - for scanner.Scan() { - if debug { - virtLog.WithField("source", "virtiofsd").Debug(scanner.Text()) - } - if !sent && strings.Contains(scanner.Text(), "Waiting for vhost-user socket connection...") { - sockReady <- nil - sent = true - } - - } - if !sent { - if err := scanner.Err(); err != nil { - sockReady <- err - - } else { - sockReady <- fmt.Errorf("virtiofsd did not announce socket connection") - - } - - } - // Wait to release resources of virtiofsd process - cmd.Process.Wait() - }() - - var err error - select { - case err = <-sockReady: - case <-time.After(virtiofsdStartTimeout * time.Second): - err = fmt.Errorf("timed out waiting for vitiofsd ready mesage pid=%d", cmd.Process.Pid) - } - - return err -} - func (v *virtiofsd) kill(ctx context.Context) (err error) { span, _ := v.trace(ctx, "kill") defer span.End() @@ -283,7 +239,7 @@ type virtiofsdMock struct { } // Start the virtiofsd daemon -func (v *virtiofsdMock) Start(ctx context.Context) (int, error) { +func (v *virtiofsdMock) Start(ctx context.Context, onQuit onQuitFunc) (int, error) { return 9999999, nil } diff --git a/src/runtime/virtcontainers/virtiofsd_test.go b/src/runtime/virtcontainers/virtiofsd_test.go index 80d8a1670c..b9f67057c9 100644 --- a/src/runtime/virtcontainers/virtiofsd_test.go +++ b/src/runtime/virtcontainers/virtiofsd_test.go @@ -7,10 +7,9 @@ package virtcontainers import ( "context" - "io" "io/ioutil" "os" - "os/exec" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -67,13 +66,9 @@ func TestVirtiofsdStart(t *testing.T) { debug: tt.fields.debug, PID: tt.fields.PID, ctx: tt.fields.ctx, - //Mock wait function - wait: func(runningCmd *exec.Cmd, stderr io.ReadCloser, debug bool) error { - return nil - }, } var ctx context.Context - _, err := v.Start(ctx) + _, err := v.Start(ctx, nil) if (err != nil) != tt.wantErr { t.Errorf("virtiofsd.Start() error = %v, wantErr %v", err, tt.wantErr) return @@ -81,3 +76,71 @@ func TestVirtiofsdStart(t *testing.T) { }) } } + +func TestVirtiofsdArgs(t *testing.T) { + assert := assert.New(t) + + v := &virtiofsd{ + path: "/usr/bin/virtiofsd", + sourcePath: "/run/kata-shared/foo", + cache: "none", + } + + expected := "--syslog -o cache=none -o no_posix_lock -o source=/run/kata-shared/foo --fd=123 -f" + args, err := v.args(123) + assert.NoError(err) + assert.Equal(expected, strings.Join(args, " ")) + + v.debug = false + expected = "--syslog -o cache=none -o no_posix_lock -o source=/run/kata-shared/foo --fd=456 -f" + args, err = v.args(456) + assert.NoError(err) + assert.Equal(expected, strings.Join(args, " ")) +} + +func TestValid(t *testing.T) { + assert := assert.New(t) + + sourcePath, err := ioutil.TempDir("", "") + assert.NoError(err) + defer os.RemoveAll(sourcePath) + + socketDir, err := ioutil.TempDir("", "") + assert.NoError(err) + defer os.RemoveAll(socketDir) + + socketPath := socketDir + "socket.s" + + newVirtiofsdFunc := func() *virtiofsd { + return &virtiofsd{ + path: "/usr/bin/virtiofsd", + sourcePath: sourcePath, + socketPath: socketPath, + } + } + + // valid case + v := newVirtiofsdFunc() + err = v.valid() + assert.NoError(err) + + v = newVirtiofsdFunc() + v.path = "" + err = v.valid() + assert.Equal(errVirtiofsdDaemonPathEmpty, err) + + v = newVirtiofsdFunc() + v.sourcePath = "" + err = v.valid() + assert.Equal(errVirtiofsdSourcePathEmpty, err) + + v = newVirtiofsdFunc() + v.socketPath = "" + err = v.valid() + assert.Equal(errVirtiofsdSocketPathEmpty, err) + + v = newVirtiofsdFunc() + v.sourcePath = "/foo/bar" + err = v.valid() + assert.Equal(errVirtiofsdSourceNotAvailable, err) +}