diff --git a/cli/create.go b/cli/create.go index 8551e8b22e..9a09448591 100644 --- a/cli/create.go +++ b/cli/create.go @@ -273,6 +273,14 @@ func createSandbox(ctx context.Context, ociSpec oci.CompatOCISpec, runtimeConfig return vc.Process{}, err } + // Run pre-start OCI hooks. + err = enterNetNS(sandboxConfig.NetworkConfig.NetNSPath, func() error { + return preStartHooks(ctx, ociSpec, containerID, bundlePath) + }) + if err != nil { + return vc.Process{}, err + } + sandbox, err := vci.CreateSandbox(ctx, sandboxConfig) if err != nil { return vc.Process{}, err @@ -331,7 +339,15 @@ func createContainer(ctx context.Context, ociSpec oci.CompatOCISpec, containerID setExternalLoggers(ctx, kataLog) span.SetTag("sandbox", sandboxID) - _, c, err := vci.CreateContainer(ctx, sandboxID, contConfig) + s, c, err := vci.CreateContainer(ctx, sandboxID, contConfig) + if err != nil { + return vc.Process{}, err + } + + // Run pre-start OCI hooks. + err = enterNetNS(s.GetNetNs(), func() error { + return preStartHooks(ctx, ociSpec, containerID, bundlePath) + }) if err != nil { return vc.Process{}, err } diff --git a/cli/delete.go b/cli/delete.go index 4b359a899a..ff939f0648 100644 --- a/cli/delete.go +++ b/cli/delete.go @@ -12,6 +12,7 @@ import ( "os" vc "github.com/kata-containers/runtime/virtcontainers" + vcAnnot "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" "github.com/kata-containers/runtime/virtcontainers/pkg/oci" "github.com/sirupsen/logrus" "github.com/urfave/cli" @@ -117,6 +118,11 @@ func delete(ctx context.Context, containerID string, force bool) error { return fmt.Errorf("Invalid container type found") } + // Run post-stop OCI hooks. + if err := postStopHooks(ctx, ociSpec, sandboxID, status.Annotations[vcAnnot.BundlePathKey]); err != nil { + return err + } + // In order to prevent any file descriptor leak related to cgroups files // that have been previously created, we have to remove them before this // function returns. diff --git a/cli/hook.go b/cli/hook.go new file mode 100644 index 0000000000..b5f96fe741 --- /dev/null +++ b/cli/hook.go @@ -0,0 +1,138 @@ +// Copyright (c) 2018 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "strings" + "syscall" + "time" + + "github.com/kata-containers/runtime/virtcontainers/pkg/oci" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/opentracing/opentracing-go/log" + "github.com/sirupsen/logrus" +) + +// Logger returns a logrus logger appropriate for logging hook messages +func hookLogger() *logrus.Entry { + return kataLog.WithField("subsystem", "hook") +} + +func runHook(ctx context.Context, hook specs.Hook, cid, bundlePath string) error { + span, _ := trace(ctx, "hook") + defer span.Finish() + + span.SetTag("subsystem", "runHook") + + span.LogFields( + log.String("hook-name", hook.Path), + log.String("hook-args", strings.Join(hook.Args, " "))) + + state := specs.State{ + Pid: os.Getpid(), + Bundle: bundlePath, + ID: cid, + } + + stateJSON, err := json.Marshal(state) + if err != nil { + return err + } + + var stdout, stderr bytes.Buffer + cmd := &exec.Cmd{ + Path: hook.Path, + Args: hook.Args, + Env: hook.Env, + Stdin: bytes.NewReader(stateJSON), + Stdout: &stdout, + Stderr: &stderr, + } + + if err := cmd.Start(); err != nil { + return err + } + + if hook.Timeout == nil { + if err := cmd.Wait(); err != nil { + return fmt.Errorf("%s: stdout: %s, stderr: %s", err, stdout.String(), stderr.String()) + } + } else { + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + close(done) + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("%s: stdout: %s, stderr: %s", err, stdout.String(), stderr.String()) + } + case <-time.After(time.Duration(*hook.Timeout) * time.Second): + if err := syscall.Kill(cmd.Process.Pid, syscall.SIGKILL); err != nil { + return err + } + + return fmt.Errorf("Hook timeout") + } + } + + return nil +} + +func runHooks(ctx context.Context, hooks []specs.Hook, cid, bundlePath, hookType string) error { + span, _ := trace(ctx, "hooks") + defer span.Finish() + + span.SetTag("subsystem", hookType) + + for _, hook := range hooks { + if err := runHook(ctx, hook, cid, bundlePath); err != nil { + hookLogger().WithFields(logrus.Fields{ + "hook-type": hookType, + "error": err, + }).Error("hook error") + + return err + } + } + + return nil +} + +func preStartHooks(ctx context.Context, spec oci.CompatOCISpec, cid, bundlePath string) error { + // If no hook available, nothing needs to be done. + if spec.Hooks == nil { + return nil + } + + return runHooks(ctx, spec.Hooks.Prestart, cid, bundlePath, "pre-start") +} + +func postStartHooks(ctx context.Context, spec oci.CompatOCISpec, cid, bundlePath string) error { + // If no hook available, nothing needs to be done. + if spec.Hooks == nil { + return nil + } + + return runHooks(ctx, spec.Hooks.Poststart, cid, bundlePath, "post-start") +} + +func postStopHooks(ctx context.Context, spec oci.CompatOCISpec, cid, bundlePath string) error { + // If no hook available, nothing needs to be done. + if spec.Hooks == nil { + return nil + } + + return runHooks(ctx, spec.Hooks.Poststop, cid, bundlePath, "post-stop") +} diff --git a/cli/hook_test.go b/cli/hook_test.go new file mode 100644 index 0000000000..0f3fed8c5c --- /dev/null +++ b/cli/hook_test.go @@ -0,0 +1,229 @@ +// Copyright (c) 2018 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package main + +import ( + "context" + "os" + "testing" + + . "github.com/kata-containers/runtime/virtcontainers/pkg/mock" + "github.com/kata-containers/runtime/virtcontainers/pkg/oci" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/assert" +) + +// Important to keep these values in sync with hook test binary +var testKeyHook = "test-key" +var testContainerIDHook = "test-container-id" +var testControllerIDHook = "test-controller-id" +var testBinHookPath = "/usr/bin/virtcontainers/bin/test/hook" +var testBundlePath = "/test/bundle" + +func getMockHookBinPath() string { + if DefaultMockHookBinPath == "" { + return testBinHookPath + } + + return DefaultMockHookBinPath +} + +func createHook(timeout int) specs.Hook { + to := &timeout + if timeout == 0 { + to = nil + } + + return specs.Hook{ + Path: getMockHookBinPath(), + Args: []string{testKeyHook, testContainerIDHook, testControllerIDHook}, + Env: os.Environ(), + Timeout: to, + } +} + +func createWrongHook() specs.Hook { + return specs.Hook{ + Path: getMockHookBinPath(), + Args: []string{"wrong-args"}, + Env: os.Environ(), + } +} + +func TestRunHook(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(testDisabledNeedNonRoot) + } + + assert := assert.New(t) + + ctx := context.Background() + + // Run with timeout 0 + hook := createHook(0) + err := runHook(ctx, hook, testSandboxID, testBundlePath) + assert.NoError(err) + + // Run with timeout 1 + hook = createHook(1) + err = runHook(ctx, hook, testSandboxID, testBundlePath) + assert.NoError(err) + + // Run timeout failure + hook = createHook(1) + hook.Args = append(hook.Args, "2") + err = runHook(ctx, hook, testSandboxID, testBundlePath) + assert.Error(err) + + // Failure due to wrong hook + hook = createWrongHook() + err = runHook(ctx, hook, testSandboxID, testBundlePath) + assert.Error(err) +} + +func TestPreStartHooks(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(testDisabledNeedNonRoot) + } + + assert := assert.New(t) + + ctx := context.Background() + + // Hooks field is nil + spec := oci.CompatOCISpec{} + err := preStartHooks(ctx, spec, "", "") + assert.NoError(err) + + // Hooks list is empty + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{}, + }, + } + err = preStartHooks(ctx, spec, "", "") + assert.NoError(err) + + // Run with timeout 0 + hook := createHook(0) + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{hook}, + }, + }, + } + err = preStartHooks(ctx, spec, testSandboxID, testBundlePath) + assert.NoError(err) + + // Failure due to wrong hook + hook = createWrongHook() + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{hook}, + }, + }, + } + err = preStartHooks(ctx, spec, testSandboxID, testBundlePath) + assert.Error(err) +} + +func TestPostStartHooks(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(testDisabledNeedNonRoot) + } + + assert := assert.New(t) + + ctx := context.Background() + + // Hooks field is nil + spec := oci.CompatOCISpec{} + err := postStartHooks(ctx, spec, "", "") + assert.NoError(err) + + // Hooks list is empty + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{}, + }, + } + err = postStartHooks(ctx, spec, "", "") + assert.NoError(err) + + // Run with timeout 0 + hook := createHook(0) + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Poststart: []specs.Hook{hook}, + }, + }, + } + err = postStartHooks(ctx, spec, testSandboxID, testBundlePath) + assert.NoError(err) + + // Failure due to wrong hook + hook = createWrongHook() + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Poststart: []specs.Hook{hook}, + }, + }, + } + err = postStartHooks(ctx, spec, testSandboxID, testBundlePath) + assert.Error(err) +} + +func TestPostStopHooks(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(testDisabledNeedNonRoot) + } + + assert := assert.New(t) + + ctx := context.Background() + + // Hooks field is nil + spec := oci.CompatOCISpec{} + err := postStopHooks(ctx, spec, "", "") + assert.NoError(err) + + // Hooks list is empty + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{}, + }, + } + err = postStopHooks(ctx, spec, "", "") + assert.NoError(err) + + // Run with timeout 0 + hook := createHook(0) + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Poststop: []specs.Hook{hook}, + }, + }, + } + err = postStopHooks(ctx, spec, testSandboxID, testBundlePath) + assert.NoError(err) + + // Failure due to wrong hook + hook = createWrongHook() + spec = oci.CompatOCISpec{ + Spec: specs.Spec{ + Hooks: &specs.Hooks{ + Poststop: []specs.Hook{hook}, + }, + }, + } + err = postStopHooks(ctx, spec, testSandboxID, testBundlePath) + assert.Error(err) +} diff --git a/cli/oci.go b/cli/oci.go index cab765fbfc..928ab3e6e7 100644 --- a/cli/oci.go +++ b/cli/oci.go @@ -14,9 +14,11 @@ import ( "net" "os" "path/filepath" + goruntime "runtime" "strings" "syscall" + "github.com/containernetworking/plugins/pkg/ns" vc "github.com/kata-containers/runtime/virtcontainers" "github.com/kata-containers/runtime/virtcontainers/pkg/oci" "github.com/opencontainers/runc/libcontainer/utils" @@ -414,3 +416,33 @@ func delContainerIDMapping(ctx context.Context, containerID string) error { return os.RemoveAll(path) } + +// enterNetNS is free from any call to a go routine, and it calls +// into runtime.LockOSThread(), meaning it won't be executed in a +// different thread than the one expected by the caller. +func enterNetNS(netNSPath string, cb func() error) error { + if netNSPath == "" { + return cb() + } + + goruntime.LockOSThread() + defer goruntime.UnlockOSThread() + + currentNS, err := ns.GetCurrentNS() + if err != nil { + return err + } + defer currentNS.Close() + + targetNS, err := ns.GetNS(netNSPath) + if err != nil { + return err + } + + if err := targetNS.Set(); err != nil { + return err + } + defer currentNS.Set() + + return cb() +} diff --git a/cli/start.go b/cli/start.go index b39ebbb2ba..d8e20664d7 100644 --- a/cli/start.go +++ b/cli/start.go @@ -11,6 +11,7 @@ import ( "fmt" vc "github.com/kata-containers/runtime/virtcontainers" + vcAnnot "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" "github.com/kata-containers/runtime/virtcontainers/pkg/oci" "github.com/sirupsen/logrus" "github.com/urfave/cli" @@ -76,14 +77,36 @@ func start(ctx context.Context, containerID string) (vc.VCSandbox, error) { return nil, err } - if containerType.IsSandbox() { - return vci.StartSandbox(ctx, sandboxID) - } - - c, err := vci.StartContainer(ctx, sandboxID, containerID) + ociSpec, err := oci.GetOCIConfig(status) if err != nil { return nil, err } - return c.Sandbox(), nil + var sandbox vc.VCSandbox + + if containerType.IsSandbox() { + s, err := vci.StartSandbox(ctx, sandboxID) + if err != nil { + return nil, err + } + + sandbox = s + } else { + c, err := vci.StartContainer(ctx, sandboxID, containerID) + if err != nil { + return nil, err + } + + sandbox = c.Sandbox() + } + + // Run post-start OCI hooks. + err = enterNetNS(sandbox.GetNetNs(), func() error { + return postStartHooks(ctx, ociSpec, sandboxID, status.Annotations[vcAnnot.BundlePathKey]) + }) + if err != nil { + return nil, err + } + + return sandbox, nil } diff --git a/cli/start_test.go b/cli/start_test.go index f3fb87a59f..5aca113621 100644 --- a/cli/start_test.go +++ b/cli/start_test.go @@ -7,6 +7,7 @@ package main import ( "context" + "encoding/json" "flag" "io/ioutil" "os" @@ -14,6 +15,7 @@ import ( vc "github.com/kata-containers/runtime/virtcontainers" vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" + "github.com/kata-containers/runtime/virtcontainers/pkg/oci" "github.com/kata-containers/runtime/virtcontainers/pkg/vcmock" "github.com/stretchr/testify/assert" "github.com/urfave/cli" @@ -58,11 +60,15 @@ func TestStartSandbox(t *testing.T) { assert.NoError(err) defer os.RemoveAll(path) + ociSpecJSON, err := json.Marshal(oci.CompatOCISpec{}) + assert.NoError(err) + testingImpl.StatusContainerFunc = func(ctx context.Context, sandboxID, containerID string) (vc.ContainerStatus, error) { return vc.ContainerStatus{ ID: sandbox.ID(), Annotations: map[string]string{ vcAnnotations.ContainerTypeKey: string(vc.PodSandbox), + vcAnnotations.ConfigJSONKey: string(ociSpecJSON), }, }, nil } @@ -132,11 +138,15 @@ func TestStartContainerSucessFailure(t *testing.T) { assert.NoError(err) defer os.RemoveAll(path) + ociSpecJSON, err := json.Marshal(oci.CompatOCISpec{}) + assert.NoError(err) + testingImpl.StatusContainerFunc = func(ctx context.Context, sandboxID, containerID string) (vc.ContainerStatus, error) { return vc.ContainerStatus{ ID: testContainerID, Annotations: map[string]string{ vcAnnotations.ContainerTypeKey: string(vc.PodContainer), + vcAnnotations.ConfigJSONKey: string(ociSpecJSON), }, }, nil } @@ -206,11 +216,15 @@ func TestStartCLIFunctionSuccess(t *testing.T) { assert.NoError(err) defer os.RemoveAll(path) + ociSpecJSON, err := json.Marshal(oci.CompatOCISpec{}) + assert.NoError(err) + testingImpl.StatusContainerFunc = func(ctx context.Context, sandboxID, containerID string) (vc.ContainerStatus, error) { return vc.ContainerStatus{ ID: testContainerID, Annotations: map[string]string{ vcAnnotations.ContainerTypeKey: string(vc.PodContainer), + vcAnnotations.ConfigJSONKey: string(ociSpecJSON), }, }, nil } diff --git a/virtcontainers/api.go b/virtcontainers/api.go index 82ebfe15d7..703ecb5868 100644 --- a/virtcontainers/api.go +++ b/virtcontainers/api.go @@ -40,14 +40,6 @@ func trace(parent context.Context, name string) (opentracing.Span, context.Conte return span, ctx } -func traceWithSubsys(ctx context.Context, subsys, name string) (opentracing.Span, context.Context) { - span, ctx := opentracing.StartSpanFromContext(ctx, name) - - span.SetTag("subsystem", subsys) - - return span, ctx -} - // SetLogger sets the logger for virtcontainers package. func SetLogger(ctx context.Context, logger *logrus.Entry) { fields := virtLog.Data @@ -236,11 +228,6 @@ func startSandbox(s *Sandbox) (*Sandbox, error) { return nil, err } - // Execute poststart hooks. - if err := s.config.Hooks.postStartHooks(s); err != nil { - return nil, err - } - return s, nil } @@ -278,11 +265,6 @@ func StopSandbox(ctx context.Context, sandboxID string) (VCSandbox, error) { return nil, err } - // Execute poststop hooks. - if err := s.config.Hooks.postStopHooks(s); err != nil { - return nil, err - } - return s, nil } diff --git a/virtcontainers/api_test.go b/virtcontainers/api_test.go index 65478430d1..11fdd253cb 100644 --- a/virtcontainers/api_test.go +++ b/virtcontainers/api_test.go @@ -143,22 +143,10 @@ func newTestSandboxConfigHyperstartAgentDefaultNetwork() SandboxConfig { SockTtyName: testHyperstartTtySocket, } - hooks := Hooks{ - PreStartHooks: []Hook{ - { - Path: getMockHookBinPath(), - Args: []string{testKeyHook, testContainerIDHook, testControllerIDHook}, - }, - }, - PostStartHooks: []Hook{}, - PostStopHooks: []Hook{}, - } - netConfig := NetworkConfig{} sandboxConfig := SandboxConfig{ - ID: testSandboxID, - Hooks: hooks, + ID: testSandboxID, HypervisorType: MockHypervisor, HypervisorConfig: hypervisorConfig, diff --git a/virtcontainers/hook.go b/virtcontainers/hook.go deleted file mode 100644 index fb751fd85c..0000000000 --- a/virtcontainers/hook.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2017 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 -// - -package virtcontainers - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "os" - "os/exec" - "strings" - "syscall" - "time" - - vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" - specs "github.com/opencontainers/runtime-spec/specs-go" - opentracing "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" - "github.com/sirupsen/logrus" -) - -// Hook represents an OCI hook, including its required parameters. -type Hook struct { - Path string - Args []string - Env []string - Timeout int -} - -// Hooks gathers all existing OCI hooks list. -type Hooks struct { - PreStartHooks []Hook - PostStartHooks []Hook - PostStopHooks []Hook -} - -// Logger returns a logrus logger appropriate for logging Hooks messages -func (h *Hooks) Logger() *logrus.Entry { - return virtLog.WithField("subsystem", "hooks") -} - -func buildHookState(processID int, s *Sandbox) specs.State { - annotations := s.GetAnnotations() - return specs.State{ - Pid: processID, - Bundle: annotations[vcAnnotations.BundlePathKey], - ID: s.id, - } -} - -func (h *Hook) trace(ctx context.Context, name string) (opentracing.Span, context.Context) { - return traceWithSubsys(ctx, "hook", name) -} - -func (h *Hooks) trace(ctx context.Context, name string) (opentracing.Span, context.Context) { - return traceWithSubsys(ctx, "hooks", name) -} - -func (h *Hook) runHook(s *Sandbox) error { - span, _ := h.trace(s.ctx, "runHook") - defer span.Finish() - - span.LogFields( - log.String("hook-name", h.Path), - log.String("hook-args", strings.Join(h.Args, " "))) - - state := buildHookState(os.Getpid(), s) - stateJSON, err := json.Marshal(state) - if err != nil { - return err - } - - var stdout, stderr bytes.Buffer - cmd := &exec.Cmd{ - Path: h.Path, - Args: h.Args, - Env: h.Env, - Stdin: bytes.NewReader(stateJSON), - Stdout: &stdout, - Stderr: &stderr, - } - - err = cmd.Start() - if err != nil { - return err - } - - if h.Timeout == 0 { - err = cmd.Wait() - if err != nil { - return fmt.Errorf("%s: stdout: %s, stderr: %s", err, stdout.String(), stderr.String()) - } - } else { - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - close(done) - }() - - select { - case err := <-done: - if err != nil { - return fmt.Errorf("%s: stdout: %s, stderr: %s", err, stdout.String(), stderr.String()) - } - case <-time.After(time.Duration(h.Timeout) * time.Second): - if err := syscall.Kill(cmd.Process.Pid, syscall.SIGKILL); err != nil { - return err - } - - return fmt.Errorf("Hook timeout") - } - } - - return nil -} - -func (h *Hooks) preStartHooks(s *Sandbox) error { - span, _ := h.trace(s.ctx, "preStartHooks") - defer span.Finish() - - if len(h.PreStartHooks) == 0 { - return nil - } - - for _, hook := range h.PreStartHooks { - err := hook.runHook(s) - if err != nil { - h.Logger().WithFields(logrus.Fields{ - "hook-type": "pre-start", - "error": err, - }).Error("hook error") - - return err - } - } - - return nil -} - -func (h *Hooks) postStartHooks(s *Sandbox) error { - span, _ := h.trace(s.ctx, "postStartHooks") - defer span.Finish() - - if len(h.PostStartHooks) == 0 { - return nil - } - - for _, hook := range h.PostStartHooks { - err := hook.runHook(s) - if err != nil { - // In case of post start hook, the error is not fatal, - // just need to be logged. - h.Logger().WithFields(logrus.Fields{ - "hook-type": "post-start", - "error": err, - }).Info("hook error") - } - } - - return nil -} - -func (h *Hooks) postStopHooks(s *Sandbox) error { - span, _ := h.trace(s.ctx, "postStopHooks") - defer span.Finish() - - if len(h.PostStopHooks) == 0 { - return nil - } - - for _, hook := range h.PostStopHooks { - err := hook.runHook(s) - if err != nil { - // In case of post stop hook, the error is not fatal, - // just need to be logged. - h.Logger().WithFields(logrus.Fields{ - "hook-type": "post-stop", - "error": err, - }).Info("hook error") - } - } - - return nil -} diff --git a/virtcontainers/hook_test.go b/virtcontainers/hook_test.go deleted file mode 100644 index cdf41339c3..0000000000 --- a/virtcontainers/hook_test.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) 2017 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 -// - -package virtcontainers - -import ( - "context" - "io/ioutil" - "os" - "path/filepath" - "reflect" - "sync" - "testing" - - vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" - . "github.com/kata-containers/runtime/virtcontainers/pkg/mock" - specs "github.com/opencontainers/runtime-spec/specs-go" -) - -// Important to keep these values in sync with hook test binary -var testKeyHook = "test-key" -var testContainerIDHook = "test-container-id" -var testControllerIDHook = "test-controller-id" -var testProcessIDHook = 12345 -var testBinHookPath = "/usr/bin/virtcontainers/bin/test/hook" -var testBundlePath = "/test/bundle" - -func getMockHookBinPath() string { - if DefaultMockHookBinPath == "" { - return testBinHookPath - } - - return DefaultMockHookBinPath -} - -func TestBuildHookState(t *testing.T) { - t.Skip() - expected := specs.State{ - Pid: testProcessIDHook, - } - - s := &Sandbox{} - - hookState := buildHookState(testProcessIDHook, s) - - if reflect.DeepEqual(hookState, expected) == false { - t.Fatal() - } - - s = createTestSandbox() - hookState = buildHookState(testProcessIDHook, s) - - expected = specs.State{ - Pid: testProcessIDHook, - Bundle: testBundlePath, - ID: testSandboxID, - } - - if reflect.DeepEqual(hookState, expected) == false { - t.Fatal() - } - -} - -func createHook(timeout int) *Hook { - return &Hook{ - Path: getMockHookBinPath(), - Args: []string{testKeyHook, testContainerIDHook, testControllerIDHook}, - Env: os.Environ(), - Timeout: timeout, - } -} - -func createWrongHook() *Hook { - return &Hook{ - Path: getMockHookBinPath(), - Args: []string{"wrong-args"}, - Env: os.Environ(), - } -} - -func createTestSandbox() *Sandbox { - c := &SandboxConfig{ - Annotations: map[string]string{ - vcAnnotations.BundlePathKey: testBundlePath, - }, - } - return &Sandbox{ - annotationsLock: &sync.RWMutex{}, - config: c, - id: testSandboxID, - ctx: context.Background(), - } -} - -func testRunHookFull(t *testing.T, timeout int, expectFail bool) { - hook := createHook(timeout) - - s := createTestSandbox() - err := hook.runHook(s) - if expectFail { - if err == nil { - t.Fatal("unexpected success") - } - } else { - if err != nil { - t.Fatalf("unexpected failure: %v", err) - } - } -} - -func testRunHook(t *testing.T, timeout int) { - testRunHookFull(t, timeout, false) -} - -func TestRunHook(t *testing.T) { - cleanUp() - - testRunHook(t, 0) -} - -func TestRunHookTimeout(t *testing.T) { - testRunHook(t, 1) -} - -func TestRunHookExitFailure(t *testing.T) { - hook := createWrongHook() - s := createTestSandbox() - - err := hook.runHook(s) - if err == nil { - t.Fatal() - } -} - -func TestRunHookTimeoutFailure(t *testing.T) { - hook := createHook(1) - - hook.Args = append(hook.Args, "2") - - s := createTestSandbox() - - err := hook.runHook(s) - if err == nil { - t.Fatal() - } -} - -func TestRunHookWaitFailure(t *testing.T) { - hook := createHook(60) - - hook.Args = append(hook.Args, "1", "panic") - s := createTestSandbox() - - err := hook.runHook(s) - if err == nil { - t.Fatal() - } -} - -func testRunHookInvalidCommand(t *testing.T, timeout int) { - cleanUp() - - dir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatal(err) - } - - defer os.RemoveAll(dir) - - cmd := filepath.Join(dir, "does-not-exist") - - savedDefaultMockHookBinPath := DefaultMockHookBinPath - DefaultMockHookBinPath = cmd - - defer func() { - DefaultMockHookBinPath = savedDefaultMockHookBinPath - }() - - testRunHookFull(t, timeout, true) -} - -func TestRunHookInvalidCommand(t *testing.T) { - testRunHookInvalidCommand(t, 0) -} - -func TestRunHookTimeoutInvalidCommand(t *testing.T) { - testRunHookInvalidCommand(t, 1) -} - -func testHooks(t *testing.T, hook *Hook) { - hooks := &Hooks{ - PreStartHooks: []Hook{*hook}, - PostStartHooks: []Hook{*hook}, - PostStopHooks: []Hook{*hook}, - } - s := createTestSandbox() - - err := hooks.preStartHooks(s) - if err != nil { - t.Fatal(err) - } - - err = hooks.postStartHooks(s) - if err != nil { - t.Fatal(err) - } - - err = hooks.postStopHooks(s) - if err != nil { - t.Fatal(err) - } -} - -func testFailingHooks(t *testing.T, hook *Hook) { - hooks := &Hooks{ - PreStartHooks: []Hook{*hook}, - PostStartHooks: []Hook{*hook}, - PostStopHooks: []Hook{*hook}, - } - s := createTestSandbox() - - err := hooks.preStartHooks(s) - if err == nil { - t.Fatal(err) - } - - err = hooks.postStartHooks(s) - if err != nil { - t.Fatal(err) - } - - err = hooks.postStopHooks(s) - if err != nil { - t.Fatal(err) - } -} - -func TestHooks(t *testing.T) { - testHooks(t, createHook(0)) -} - -func TestHooksTimeout(t *testing.T) { - testHooks(t, createHook(1)) -} - -func TestFailingHooks(t *testing.T) { - testFailingHooks(t, createWrongHook()) -} - -func TestEmptyHooks(t *testing.T) { - hooks := &Hooks{} - s := createTestSandbox() - - err := hooks.preStartHooks(s) - if err != nil { - t.Fatal(err) - } - - err = hooks.postStartHooks(s) - if err != nil { - t.Fatal(err) - } - - err = hooks.postStopHooks(s) - if err != nil { - t.Fatal(err) - } -} diff --git a/virtcontainers/pkg/oci/utils.go b/virtcontainers/pkg/oci/utils.go index 573f517507..a5bd5a57ea 100644 --- a/virtcontainers/pkg/oci/utils.go +++ b/virtcontainers/pkg/oci/utils.go @@ -157,43 +157,6 @@ func cmdEnvs(spec CompatOCISpec, envs []vc.EnvVar) []vc.EnvVar { return envs } -func newHook(h spec.Hook) vc.Hook { - timeout := 0 - if h.Timeout != nil { - timeout = *h.Timeout - } - - return vc.Hook{ - Path: h.Path, - Args: h.Args, - Env: h.Env, - Timeout: timeout, - } -} - -func containerHooks(spec CompatOCISpec) vc.Hooks { - ociHooks := spec.Hooks - if ociHooks == nil { - return vc.Hooks{} - } - - var hooks vc.Hooks - - for _, h := range ociHooks.Prestart { - hooks.PreStartHooks = append(hooks.PreStartHooks, newHook(h)) - } - - for _, h := range ociHooks.Poststart { - hooks.PostStartHooks = append(hooks.PostStartHooks, newHook(h)) - } - - for _, h := range ociHooks.Poststop { - hooks.PostStopHooks = append(hooks.PostStopHooks, newHook(h)) - } - - return hooks -} - func newMount(m spec.Mount) vc.Mount { return vc.Mount{ Source: m.Source, @@ -517,8 +480,6 @@ func SandboxConfig(ocispec CompatOCISpec, runtime RuntimeConfig, bundlePath, cid Hostname: ocispec.Hostname, - Hooks: containerHooks(ocispec), - VMConfig: resources, HypervisorType: runtime.HypervisorType, diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index 3fb0fd217d..e39cfe3eee 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -319,9 +319,6 @@ type SandboxConfig struct { Hostname string - // Field specific to OCI specs, needed to setup all the hooks - Hooks Hooks - // VMConfig is the VM configuration to set for this sandbox. VMConfig Resources @@ -976,13 +973,6 @@ func (s *Sandbox) createNetwork() error { span, _ := s.trace("createNetwork") defer span.Finish() - // Execute prestart hooks inside netns - if err := s.network.run(s.config.NetworkConfig.NetNSPath, func() error { - return s.config.Hooks.preStartHooks(s) - }); err != nil { - return err - } - // Add the network if err := s.network.add(s); err != nil { return err