diff --git a/devicemodel/hw/pci/virtio/virtio_block.c b/devicemodel/hw/pci/virtio/virtio_block.c index e69524e2e..b368cd6eb 100644 --- a/devicemodel/hw/pci/virtio/virtio_block.c +++ b/devicemodel/hw/pci/virtio/virtio_block.c @@ -31,7 +31,6 @@ #include #include #include -#include #include #include @@ -219,6 +218,15 @@ virtio_blk_done(struct blockif_req *br, int err) pthread_mutex_unlock(&blk->mtx); } +static void +virtio_blk_abort(struct virtio_vq_info *vq, uint16_t idx) +{ + if (idx < vq->qsize) { + vq_relchain(vq, idx, 1); + vq_endchains(vq, 0); + } +} + static void virtio_blk_proc(struct virtio_blk *blk, struct virtio_vq_info *vq) { @@ -231,6 +239,7 @@ virtio_blk_proc(struct virtio_blk *blk, struct virtio_vq_info *vq) struct iovec iov[BLOCKIF_IOV_MAX + 2]; uint16_t idx, flags[BLOCKIF_IOV_MAX + 2]; + idx = vq->qsize; n = vq_getchain(vq, &idx, iov, BLOCKIF_IOV_MAX + 2, flags); /* @@ -241,18 +250,36 @@ virtio_blk_proc(struct virtio_blk *blk, struct virtio_vq_info *vq) * XXX - note - this fails on crash dump, which does a * VIRTIO_BLK_T_FLUSH with a zero transfer length */ - assert(n >= 2 && n <= BLOCKIF_IOV_MAX + 2); + if (n < 2 || n > BLOCKIF_IOV_MAX + 2) { + WPRINTF(("%s: vq_getchain failed\n", __func__)); + virtio_blk_abort(vq, idx); + return; + } io = &blk->ios[idx]; - assert((flags[0] & VRING_DESC_F_WRITE) == 0); - assert(iov[0].iov_len == sizeof(struct virtio_blk_hdr)); + if ((flags[0] & VRING_DESC_F_WRITE) != 0) { + WPRINTF(("%s: the type for hdr should not be VRING_DESC_F_WRITE\n", __func__)); + virtio_blk_abort(vq, idx); + return; + } + if (iov[0].iov_len != sizeof(struct virtio_blk_hdr)) { + WPRINTF(("%s: the size for hdr %ld should be %ld \n", + __func__, + iov[0].iov_len, + sizeof(struct virtio_blk_hdr))); + virtio_blk_abort(vq, idx); + return; + } vbh = iov[0].iov_base; memcpy(&io->req.iov, &iov[1], sizeof(struct iovec) * (n - 2)); io->req.iovcnt = n - 2; io->req.offset = vbh->sector * DEV_BSIZE; io->status = iov[--n].iov_base; - assert(iov[n].iov_len == 1); - assert(flags[n] & VRING_DESC_F_WRITE); + if (iov[n].iov_len != 1 || ((flags[n] & VRING_DESC_F_WRITE) == 0)) { + WPRINTF(("%s: status iov is invalid!\n", __func__)); + virtio_blk_abort(vq, idx); + return; + } /* * XXX @@ -283,7 +310,11 @@ virtio_blk_proc(struct virtio_blk *blk, struct virtio_vq_info *vq) * therefore test the inverse of the descriptor bit * to the op. */ - assert(((flags[i] & VRING_DESC_F_WRITE) == 0) == writeop); + if (((flags[i] & VRING_DESC_F_WRITE) == 0) != writeop) { + WPRINTF(("%s: flag is confict with operation\n", __func__)); + virtio_blk_done(&io->req, EINVAL); + return; + } iolen += iov[i].iov_len; } io->req.resid = iolen; @@ -337,7 +368,8 @@ virtio_blk_proc(struct virtio_blk *blk, struct virtio_vq_info *vq) virtio_blk_done(&io->req, EOPNOTSUPP); return; } - assert(err == 0); + if (err) + WPRINTF(("%s: request process failed\n", __func__)); } static void