diff --git a/pkg/ns/ns.go b/pkg/ns/ns.go index 119a8ce8..837ab8be 100644 --- a/pkg/ns/ns.go +++ b/pkg/ns/ns.go @@ -58,6 +58,7 @@ type NetNS interface { type netNS struct { file *os.File mounted bool + closed bool } func getCurrentThreadNetNSPath() string { @@ -165,8 +166,22 @@ func (ns *netNS) Fd() uintptr { return ns.file.Fd() } +func (ns *netNS) errorIfClosed() error { + if ns.closed { + return fmt.Errorf("%q has already been closed", ns.file.Name()) + } + return nil +} + func (ns *netNS) Close() error { - ns.file.Close() + if err := ns.errorIfClosed(); err != nil { + return err + } + + if err := ns.file.Close(); err != nil { + return fmt.Errorf("Failed to close %q: %v", ns.file.Name(), err) + } + ns.closed = true if ns.mounted { if err := unix.Unmount(ns.file.Name(), unix.MNT_DETACH); err != nil { @@ -175,11 +190,17 @@ func (ns *netNS) Close() error { if err := os.RemoveAll(ns.file.Name()); err != nil { return fmt.Errorf("Failed to clean up namespace %s: %v", ns.file.Name(), err) } + ns.mounted = false } + return nil } func (ns *netNS) Do(toRun func(NetNS) error) error { + if err := ns.errorIfClosed(); err != nil { + return err + } + containedCall := func(hostNS NetNS) error { threadNS, err := GetNS(getCurrentThreadNetNSPath()) if err != nil { @@ -218,6 +239,10 @@ func (ns *netNS) Do(toRun func(NetNS) error) error { } func (ns *netNS) Set() error { + if err := ns.errorIfClosed(); err != nil { + return err + } + if _, _, err := unix.Syscall(unix.SYS_SETNS, ns.Fd(), uintptr(unix.CLONE_NEWNET), 0); err != 0 { return fmt.Errorf("Error switching to ns %v: %v", ns.file.Name(), err) } diff --git a/pkg/ns/ns_test.go b/pkg/ns/ns_test.go index 836025e9..de0f3853 100644 --- a/pkg/ns/ns_test.go +++ b/pkg/ns/ns_test.go @@ -170,6 +170,33 @@ var _ = Describe("Linux namespace operations", func() { } }) }) + + Describe("closing a network namespace", func() { + It("should prevent further operations", func() { + createdNetNS, err := ns.NewNS() + Expect(err).NotTo(HaveOccurred()) + + err = createdNetNS.Close() + Expect(err).NotTo(HaveOccurred()) + + err = createdNetNS.Do(func(ns.NetNS) error { return nil }) + Expect(err).To(HaveOccurred()) + + err = createdNetNS.Set() + Expect(err).To(HaveOccurred()) + }) + + It("should only work once", func() { + createdNetNS, err := ns.NewNS() + Expect(err).NotTo(HaveOccurred()) + + err = createdNetNS.Close() + Expect(err).NotTo(HaveOccurred()) + + err = createdNetNS.Close() + Expect(err).To(HaveOccurred()) + }) + }) }) })