vc: modify ioctl function to handle shim test

The kata shim tests make use of an ioctl function, so instead
of having a custom one within that file, use the ioctl
function in utils/utils_linux

Fixes #1419

Signed-off-by: Gabi Beyer <Gabrielle.n.beyer@intel.com>
This commit is contained in:
Gabi Beyer 2019-04-10 13:30:09 -07:00
parent c42507903d
commit b08ab6ae1f
3 changed files with 12 additions and 21 deletions

View File

@ -18,6 +18,7 @@ import (
"unsafe" "unsafe"
. "github.com/kata-containers/runtime/virtcontainers/pkg/mock" . "github.com/kata-containers/runtime/virtcontainers/pkg/mock"
"github.com/kata-containers/runtime/virtcontainers/utils"
) )
const ( const (
@ -281,26 +282,18 @@ func TestKataShimStartWithConsoleNonExistingFailure(t *testing.T) {
testKataShimStart(t, sandbox, params, true) testKataShimStart(t, sandbox, params, true)
} }
func ioctl(fd uintptr, flag, data uintptr) error {
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, flag, data); err != 0 {
return err
}
return nil
}
// unlockpt unlocks the slave pseudoterminal device corresponding to the master pseudoterminal referred to by f. // unlockpt unlocks the slave pseudoterminal device corresponding to the master pseudoterminal referred to by f.
func unlockpt(f *os.File) error { func unlockpt(f *os.File) error {
var u int32 var u int32
return ioctl(f.Fd(), syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&u))) return utils.Ioctl(f.Fd(), syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&u)))
} }
// ptsname retrieves the name of the first available pts for the given master. // ptsname retrieves the name of the first available pts for the given master.
func ptsname(f *os.File) (string, error) { func ptsname(f *os.File) (string, error) {
var n int32 var n int32
if err := ioctl(f.Fd(), syscall.TIOCGPTN, uintptr(unsafe.Pointer(&n))); err != nil { if err := utils.Ioctl(f.Fd(), syscall.TIOCGPTN, uintptr(unsafe.Pointer(&n))); err != nil {
return "", err return "", err
} }

View File

@ -20,20 +20,18 @@ import (
// VHOST_VSOCK_SET_GUEST_CID = _IOW(VHOST_VIRTIO, 0x60, __u64) // VHOST_VSOCK_SET_GUEST_CID = _IOW(VHOST_VIRTIO, 0x60, __u64)
const ioctlVhostVsockSetGuestCid = 0x4008AF60 const ioctlVhostVsockSetGuestCid = 0x4008AF60
var ioctlFunc = ioctl var ioctlFunc = Ioctl
// maxUInt represents the maximum valid value for the context ID. // maxUInt represents the maximum valid value for the context ID.
// The upper 32 bits of the CID are reserved and zeroed. // The upper 32 bits of the CID are reserved and zeroed.
// See http://stefanha.github.io/virtio/ // See http://stefanha.github.io/virtio/
var maxUInt uint64 = 1<<32 - 1 var maxUInt uint64 = 1<<32 - 1
func ioctl(fd uintptr, request int, arg1 uint64) error { func Ioctl(fd uintptr, request, data uintptr) error {
if _, _, errno := unix.Syscall( if _, _, errno := unix.Syscall(unix.SYS_IOCTL, fd, request, data); errno != 0 {
unix.SYS_IOCTL, //uintptr(request)
fd, //uintptr(unsafe.Pointer(&arg1)),
uintptr(request), //); errno != 0 {
uintptr(unsafe.Pointer(&arg1)),
); errno != 0 {
return os.NewSyscallError("ioctl", fmt.Errorf("%d", int(errno))) return os.NewSyscallError("ioctl", fmt.Errorf("%d", int(errno)))
} }
@ -75,14 +73,14 @@ func FindContextID() (*os.File, uint64, error) {
// Looking for the first available context ID. // Looking for the first available context ID.
for cid := contextID; cid <= maxUInt; cid++ { for cid := contextID; cid <= maxUInt; cid++ {
if err := ioctlFunc(vsockFd.Fd(), ioctlVhostVsockSetGuestCid, cid); err == nil { if err := ioctlFunc(vsockFd.Fd(), ioctlVhostVsockSetGuestCid, uintptr(unsafe.Pointer(&cid))); err == nil {
return vsockFd, cid, nil return vsockFd, cid, nil
} }
} }
// Last chance to get a free context ID. // Last chance to get a free context ID.
for cid := contextID - 1; cid >= firstContextID; cid-- { for cid := contextID - 1; cid >= firstContextID; cid-- {
if err := ioctlFunc(vsockFd.Fd(), ioctlVhostVsockSetGuestCid, cid); err == nil { if err := ioctlFunc(vsockFd.Fd(), ioctlVhostVsockSetGuestCid, uintptr(unsafe.Pointer(&cid))); err == nil {
return vsockFd, cid, nil return vsockFd, cid, nil
} }
} }

View File

@ -15,7 +15,7 @@ import (
func TestFindContextID(t *testing.T) { func TestFindContextID(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
ioctlFunc = func(fd uintptr, request int, arg1 uint64) error { ioctlFunc = func(fd uintptr, request, arg1 uintptr) error {
return errors.New("ioctl") return errors.New("ioctl")
} }