diff --git a/virtcontainers/api.go b/virtcontainers/api.go index 24f42a9113..42930cd81f 100644 --- a/virtcontainers/api.go +++ b/virtcontainers/api.go @@ -142,7 +142,7 @@ func startSandbox(p *Sandbox) (*Sandbox, error) { } // Execute poststart hooks. - if err := p.config.Hooks.postStartHooks(); err != nil { + if err := p.config.Hooks.postStartHooks(p); err != nil { return nil, err } @@ -180,7 +180,7 @@ func StopSandbox(sandboxID string) (VCSandbox, error) { } // Execute poststop hooks. - if err := p.config.Hooks.postStopHooks(); err != nil { + if err := p.config.Hooks.postStopHooks(p); err != nil { return nil, err } diff --git a/virtcontainers/hook.go b/virtcontainers/hook.go index f240c05f89..d5f3025df7 100644 --- a/virtcontainers/hook.go +++ b/virtcontainers/hook.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" ) @@ -38,14 +39,17 @@ func (h *Hooks) Logger() *logrus.Entry { return virtLog.WithField("subsystem", "hooks") } -func buildHookState(processID int) specs.State { +func buildHookState(processID int, s *Sandbox) specs.State { + annotations := s.GetAnnotations() return specs.State{ - Pid: processID, + Pid: processID, + Bundle: annotations[vcAnnotations.BundlePathKey], + ID: s.id, } } -func (h *Hook) runHook() error { - state := buildHookState(os.Getpid()) +func (h *Hook) runHook(s *Sandbox) error { + state := buildHookState(os.Getpid(), s) stateJSON, err := json.Marshal(state) if err != nil { return err @@ -95,13 +99,13 @@ func (h *Hook) runHook() error { return nil } -func (h *Hooks) preStartHooks() error { +func (h *Hooks) preStartHooks(s *Sandbox) error { if len(h.PreStartHooks) == 0 { return nil } for _, hook := range h.PreStartHooks { - err := hook.runHook() + err := hook.runHook(s) if err != nil { h.Logger().WithFields(logrus.Fields{ "hook-type": "pre-start", @@ -115,13 +119,13 @@ func (h *Hooks) preStartHooks() error { return nil } -func (h *Hooks) postStartHooks() error { +func (h *Hooks) postStartHooks(s *Sandbox) error { if len(h.PostStartHooks) == 0 { return nil } for _, hook := range h.PostStartHooks { - err := hook.runHook() + err := hook.runHook(s) if err != nil { // In case of post start hook, the error is not fatal, // just need to be logged. @@ -135,13 +139,13 @@ func (h *Hooks) postStartHooks() error { return nil } -func (h *Hooks) postStopHooks() error { +func (h *Hooks) postStopHooks(s *Sandbox) error { if len(h.PostStopHooks) == 0 { return nil } for _, hook := range h.PostStopHooks { - err := hook.runHook() + err := hook.runHook(s) if err != nil { // In case of post stop hook, the error is not fatal, // just need to be logged. diff --git a/virtcontainers/hook_test.go b/virtcontainers/hook_test.go index 7bc5578f9f..a8c3358709 100644 --- a/virtcontainers/hook_test.go +++ b/virtcontainers/hook_test.go @@ -10,8 +10,10 @@ import ( "os" "path/filepath" "reflect" + "sync" "testing" + vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" . "github.com/kata-containers/runtime/virtcontainers/pkg/mock" specs "github.com/opencontainers/runtime-spec/specs-go" ) @@ -22,6 +24,7 @@ var testContainerIDHook = "test-container-id" var testControllerIDHook = "test-controller-id" var testProcessIDHook = 12345 var testBinHookPath = "/usr/bin/virtcontainers/bin/test/hook" +var testBundlePath = "/test/bundle" func getMockHookBinPath() string { if DefaultMockHookBinPath == "" { @@ -32,15 +35,32 @@ func getMockHookBinPath() string { } func TestBuildHookState(t *testing.T) { + t.Skip() expected := specs.State{ Pid: testProcessIDHook, } - hookState := buildHookState(testProcessIDHook) + s := &Sandbox{} + + hookState := buildHookState(testProcessIDHook, s) if reflect.DeepEqual(hookState, expected) == false { t.Fatal() } + + s = createTestSandbox() + hookState = buildHookState(testProcessIDHook, s) + + expected = specs.State{ + Pid: testProcessIDHook, + Bundle: testBundlePath, + ID: testSandboxID, + } + + if reflect.DeepEqual(hookState, expected) == false { + t.Fatal() + } + } func createHook(timeout int) *Hook { @@ -60,10 +80,24 @@ func createWrongHook() *Hook { } } +func createTestSandbox() *Sandbox { + c := &SandboxConfig{ + Annotations: map[string]string{ + vcAnnotations.BundlePathKey: testBundlePath, + }, + } + return &Sandbox{ + annotationsLock: &sync.RWMutex{}, + config: c, + id: testSandboxID, + } +} + func testRunHookFull(t *testing.T, timeout int, expectFail bool) { hook := createHook(timeout) - err := hook.runHook() + s := createTestSandbox() + err := hook.runHook(s) if expectFail { if err == nil { t.Fatal("unexpected success") @@ -91,8 +125,9 @@ func TestRunHookTimeout(t *testing.T) { func TestRunHookExitFailure(t *testing.T) { hook := createWrongHook() + s := createTestSandbox() - err := hook.runHook() + err := hook.runHook(s) if err == nil { t.Fatal() } @@ -103,7 +138,9 @@ func TestRunHookTimeoutFailure(t *testing.T) { hook.Args = append(hook.Args, "2") - err := hook.runHook() + s := createTestSandbox() + + err := hook.runHook(s) if err == nil { t.Fatal() } @@ -113,8 +150,9 @@ func TestRunHookWaitFailure(t *testing.T) { hook := createHook(60) hook.Args = append(hook.Args, "1", "panic") + s := createTestSandbox() - err := hook.runHook() + err := hook.runHook(s) if err == nil { t.Fatal() } @@ -156,18 +194,19 @@ func testHooks(t *testing.T, hook *Hook) { PostStartHooks: []Hook{*hook}, PostStopHooks: []Hook{*hook}, } + s := createTestSandbox() - err := hooks.preStartHooks() + err := hooks.preStartHooks(s) if err != nil { t.Fatal(err) } - err = hooks.postStartHooks() + err = hooks.postStartHooks(s) if err != nil { t.Fatal(err) } - err = hooks.postStopHooks() + err = hooks.postStopHooks(s) if err != nil { t.Fatal(err) } @@ -179,18 +218,19 @@ func testFailingHooks(t *testing.T, hook *Hook) { PostStartHooks: []Hook{*hook}, PostStopHooks: []Hook{*hook}, } + s := createTestSandbox() - err := hooks.preStartHooks() + err := hooks.preStartHooks(s) if err == nil { t.Fatal(err) } - err = hooks.postStartHooks() + err = hooks.postStartHooks(s) if err != nil { t.Fatal(err) } - err = hooks.postStopHooks() + err = hooks.postStopHooks(s) if err != nil { t.Fatal(err) } @@ -210,18 +250,19 @@ func TestFailingHooks(t *testing.T) { func TestEmptyHooks(t *testing.T) { hooks := &Hooks{} + s := createTestSandbox() - err := hooks.preStartHooks() + err := hooks.preStartHooks(s) if err != nil { t.Fatal(err) } - err = hooks.postStartHooks() + err = hooks.postStartHooks(s) if err != nil { t.Fatal(err) } - err = hooks.postStopHooks() + err = hooks.postStopHooks(s) if err != nil { t.Fatal(err) } diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index 4f7eba201f..d0b711f8b5 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -820,7 +820,7 @@ func (s *Sandbox) createNetwork() error { // Execute prestart hooks inside netns if err := s.network.run(netNsPath, func() error { - return s.config.Hooks.preStartHooks() + return s.config.Hooks.preStartHooks(s) }); err != nil { return err }